diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4a12f29..fce8bea 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -7,22 +7,19 @@ project(CUDANet
 find_package(CUDAToolkit REQUIRED)
 include_directories(${CUDAToolkit_INCLUDE_DIRS})
 
+file(GLOB_RECURSE LIBRARY_SOURCES
+    src/*.cu
+    src/utils/*.cu
+    src/kernels/*.cu
+    src/layers/*.cu)
+
 set(LIBRARY_SOURCES
-    src/utils/cuda_helper.cu
-    src/kernels/activation_functions.cu
-    src/kernels/convolution.cu
-    src/kernels/matmul.cu
-    src/layers/add.cu
-    src/layers/dense.cu
-    src/layers/conv2d.cu
-    src/layers/concat.cu
-    src/layers/input.cu
-    src/layers/activation.cu
+    ${LIBRARY_SOURCES}
 )
 
 set(CMAKE_CUDA_ARCHITECTURES 75)
 set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
-set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} -arch=sm_75)
+# set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} -arch=sm_75)
 
 # Build static library
 add_library(${PROJECT_NAME} STATIC ${LIBRARY_SOURCES})
diff --git a/include/layers/activation.cuh b/include/layers/activation.cuh
index 3a407e6..66f492f 100644
--- a/include/layers/activation.cuh
+++ b/include/layers/activation.cuh
@@ -13,6 +13,10 @@ namespace CUDANet::Layers {
  */
 enum ActivationType { SIGMOID, RELU, SOFTMAX, NONE };
 
+/**
+ * @brief Utility class that performs activation
+ *
+ */
 class Activation {
   public:
 
diff --git a/include/layers/add.cuh b/include/layers/add.cuh
index 653b699..5efb545 100644
--- a/include/layers/add.cuh
+++ b/include/layers/add.cuh
@@ -19,14 +19,13 @@ class Add {
     ~Add();
 
     /**
-     * @brief Adds the two inputs
+     * @brief Adds first input to second input
      *
      * @param d_inputA Device pointer to the first input
      * @param d_inputB Device pointer to the second input
      *
-     * @return Device pointer to the output
      */
-    float* forward(const float* d_inputA, const float* d_inputB);
+    void forward(const float* d_inputA, const float* d_inputB);
 
   private:
     int inputSize;
diff --git a/include/layers/conv2d.cuh b/include/layers/conv2d.cuh
index b74109c..40ad3c5 100644
--- a/include/layers/conv2d.cuh
+++ b/include/layers/conv2d.cuh
@@ -6,7 +6,7 @@
 #include "activation.cuh"
 #include "convolution.cuh"
-#include "weighted_layer.cuh"
+#include "layer.cuh"
 
 namespace CUDANet::Layers {
 
diff --git a/include/layers/dense.cuh b/include/layers/dense.cuh
index 7bb4a71..3922a3f 100644
--- a/include/layers/dense.cuh
+++ b/include/layers/dense.cuh
@@ -5,7 +5,7 @@
 #include
 #include
 
-#include "weighted_layer.cuh"
+#include "layer.cuh"
 
 namespace CUDANet::Layers {
 
diff --git a/include/layers/weighted_layer.cuh b/include/layers/layer.cuh
similarity index 58%
rename from include/layers/weighted_layer.cuh
rename to include/layers/layer.cuh
index f538dad..3c1be19 100644
--- a/include/layers/weighted_layer.cuh
+++ b/include/layers/layer.cuh
@@ -8,7 +8,7 @@ namespace CUDANet::Layers {
 /**
  * @brief Padding types
- * 
+ *
  * SAME: Zero padding such that the output size is the same as the input
  * VALID: No padding
  *
  */
@@ -16,40 +16,60 @@
 enum Padding { SAME, VALID };
 
 /**
- * @brief Base class for all layers
+ * @brief Basic Sequential Layer
+ *
  */
-class WeightedLayer {
+class SequentialLayer {
+  public:
+    /**
+     * @brief Destroy the Sequential Layer
+     *
+     */
+    virtual ~SequentialLayer() {};
+
+    /**
+     * @brief Forward propagation virtual function
+     *
+     * @param input Device pointer to the input
+     * @return float* Device pointer to the output
+     */
+    virtual float* forward(const float* input) = 0;
+};
+
+/**
+ * @brief Base class for layers with weights and biases
+ */
+class WeightedLayer : public SequentialLayer {
   public:
     /**
     * @brief Destroy the ILayer object
-     * 
+     *
     */
-    virtual ~WeightedLayer() {}
+    virtual ~WeightedLayer() {};
 
     /**
     * @brief Virtual function for forward pass
-     * 
+     *
     * @param input (Device) Pointer to the input
     * @return float* Device pointer to the output
     */
-    virtual float* forward(const float* input) = 0;
+    virtual float* forward(const float* input) = 0;
 
     /**
     * @brief Virtual function for setting weights
-     * 
+     *
    * @param weights Pointer to the weights
     */
-    virtual void setWeights(const float* weights) = 0;
+    virtual void setWeights(const float* weights) = 0;
 
     /**
    * @brief Virtual function for setting biases
-     * 
+     *
    * @param biases Pointer to the biases
     */
-    virtual void setBiases(const float* biases) = 0;
+    virtual void setBiases(const float* biases) = 0;
 
   private:
-
    /**
    * @brief Initialize the weights
    */
@@ -58,7 +78,7 @@ class WeightedLayer {
    /**
    * @brief Initialize the biases
    */
-    virtual void initializeBiases() = 0;
+    virtual void initializeBiases() = 0;
 
    /**
    * @brief Copy the weights and biases to the device
diff --git a/src/layers/activation.cu b/src/layers/activation.cu
index 7ceb1b8..3f70f38 100644
--- a/src/layers/activation.cu
+++ b/src/layers/activation.cu
@@ -3,9 +3,9 @@
 #include "cuda_helper.cuh"
 #include "activation_functions.cuh"
 
-using namespace CUDANet;
+using namespace CUDANet::Layers;
 
-Layers::Activation::Activation(ActivationType activation, const unsigned int length)
+Activation::Activation(ActivationType activation, const unsigned int length)
     : activationType(activation), length(length) {
 
     if (activationType == SOFTMAX) {
@@ -16,13 +16,13 @@ Layers::Activation::Activation(ActivationType activation, const unsigned int len
     gridSize = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
 }
 
-Layers::Activation::~Activation() {
+Activation::~Activation() {
     if (activationType == SOFTMAX) {
         cudaFree(d_softmax_sum);
     }
 }
 
-void Layers::Activation::activate(float* __restrict__ d_input) {
+void Activation::activate(float* __restrict__ d_input) {
 
     switch (activationType) {
         case SIGMOID:
diff --git a/src/layers/add.cu b/src/layers/add.cu
index 4b528a9..7539e5f 100644
--- a/src/layers/add.cu
+++ b/src/layers/add.cu
@@ -2,10 +2,10 @@
 #include "matmul.cuh"
 #include "cuda_helper.cuh"
 
-using namespace CUDANet;
+using namespace CUDANet::Layers;
 
 
-Layers::Add::Add(int inputSize)
+Add::Add(int inputSize)
     : inputSize(inputSize) {
 
     d_output = nullptr;
@@ -15,12 +15,12 @@ Layers::Add::Add(int inputSize)
 
 }
 
-Layers::Add::~Add() {
+Add::~Add() {
     cudaFree(d_output);
 }
 
-float* Layers::Add::forward(const float* d_inputA, const float* d_inputB) {
+void Add::forward(const float* d_inputA, const float* d_inputB) {
 
     Kernels::vec_vec_add<<>>(
         d_inputA, d_inputB, d_output, inputSize
diff --git a/src/layers/concat.cu b/src/layers/concat.cu
index 127017d..94284e6 100644
--- a/src/layers/concat.cu
+++ b/src/layers/concat.cu
@@ -1,10 +1,10 @@
 #include "concat.cuh"
 #include "cuda_helper.cuh"
 
-using namespace CUDANet;
+using namespace CUDANet::Layers;
 
 
-Layers::Concat::Concat(const unsigned int inputASize, const unsigned int inputBSize)
+Concat::Concat(const unsigned int inputASize, const unsigned int inputBSize)
     : inputASize(inputASize), inputBSize(inputBSize) {
 
     d_output = nullptr;
@@ -14,12 +14,12 @@ Layers::Concat::Concat(const unsigned int inputASize, const unsigned int inputBS
 
 }
 
-Layers::Concat::~Concat() {
+Concat::~Concat() {
     cudaFree(d_output);
 }
 
-float* Layers::Concat::forward(const float* d_input_A, const float* d_input_B) {
+float* Concat::forward(const float* d_input_A, const float* d_input_B) {
     CUDA_CHECK(cudaMemcpy(
         d_output, d_input_A, sizeof(float) * inputASize, cudaMemcpyDeviceToDevice
     ));
diff --git a/src/layers/conv2d.cu b/src/layers/conv2d.cu
index 719ec9e..79e3c88 100644
--- a/src/layers/conv2d.cu
+++ b/src/layers/conv2d.cu
@@ -7,16 +7,16 @@
 #include "cuda_helper.cuh"
 #include "matmul.cuh"
 
-using namespace CUDANet;
+using namespace CUDANet::Layers;
 
-Layers::Conv2d::Conv2d(
+Conv2d::Conv2d(
     int inputSize,
     int inputChannels,
     int kernelSize,
     int stride,
     int numFilters,
-    Layers::Padding padding,
-    Layers::ActivationType activationType
+    Padding padding,
+    ActivationType activationType
 )
     : inputSize(inputSize),
       inputChannels(inputChannels),
@@ -68,31 +68,31 @@ Layers::Conv2d::Conv2d(
     toCuda();
 }
 
-Layers::Conv2d::~Conv2d() {
+Conv2d::~Conv2d() {
     cudaFree(d_output);
     cudaFree(d_weights);
    cudaFree(d_biases);
 }
 
-void Layers::Conv2d::initializeWeights() {
+void Conv2d::initializeWeights() {
     std::fill(weights.begin(), weights.end(), 0.0f);
 }
 
-void Layers::Conv2d::initializeBiases() {
+void Conv2d::initializeBiases() {
     std::fill(biases.begin(), biases.end(), 0.0f);
 }
 
-void Layers::Conv2d::setWeights(const float* weights_input) {
+void Conv2d::setWeights(const float* weights_input) {
     std::copy(weights_input, weights_input + weights.size(), weights.begin());
     toCuda();
 }
 
-void Layers::Conv2d::setBiases(const float* biases_input) {
+void Conv2d::setBiases(const float* biases_input) {
     std::copy(biases_input, biases_input + biases.size(), biases.begin());
     toCuda();
 }
 
-void Layers::Conv2d::toCuda() {
+void Conv2d::toCuda() {
     CUDA_CHECK(cudaMemcpy(
         d_weights, weights.data(),
         sizeof(float) * kernelSize * kernelSize * inputChannels * numFilters,
@@ -106,7 +106,7 @@ void Layers::Conv2d::toCuda() {
     ));
 }
 
-float* Layers::Conv2d::forward(const float* d_input) {
+float* Conv2d::forward(const float* d_input) {
     // Convolve
     int THREADS_PER_BLOCK = outputSize * outputSize * numFilters;
     Kernels::convolution<<<1, THREADS_PER_BLOCK>>>(
diff --git a/src/layers/dense.cu b/src/layers/dense.cu
index 11896da..751b5a9 100644
--- a/src/layers/dense.cu
+++ b/src/layers/dense.cu
@@ -10,19 +10,19 @@
 #include "dense.cuh"
 #include "matmul.cuh"
 
-using namespace CUDANet;
+using namespace CUDANet::Layers;
 
-Layers::Dense::Dense(
+Dense::Dense(
     int inputSize,
     int outputSize,
-    Layers::ActivationType activationType
+    ActivationType activationType
 )
     : inputSize(inputSize), outputSize(outputSize) {
     // Allocate memory for weights and biases
     weights.resize(outputSize * inputSize);
     biases.resize(outputSize);
 
-    activation = Layers::Activation(activationType, outputSize);
+    activation = Activation(activationType, outputSize);
 
     initializeWeights();
     initializeBiases();
@@ -47,22 +47,22 @@ Layers::Dense::Dense(
     biasGridSize = (outputSize + BLOCK_SIZE - 1) / BLOCK_SIZE;
 }
 
-Layers::Dense::~Dense() {
+Dense::~Dense() {
     // Free GPU memory
     cudaFree(d_output);
     cudaFree(d_weights);
     cudaFree(d_biases);
 }
 
-void Layers::Dense::initializeWeights() {
+void Dense::initializeWeights() {
     std::fill(weights.begin(), weights.end(), 0.0f);
 }
 
-void Layers::Dense::initializeBiases() {
+void Dense::initializeBiases() {
     std::fill(biases.begin(), biases.end(), 0.0f);
 }
 
-float* Layers::Dense::forward(const float* d_input) {
+float* Dense::forward(const float* d_input) {
     Kernels::mat_vec_mul<<>>(
         d_weights, d_input, d_output, inputSize, outputSize
     );
@@ -78,7 +78,7 @@
     return d_output;
 }
 
-void Layers::Dense::toCuda() {
+void Dense::toCuda() {
     CUDA_CHECK(cudaMemcpy(
         d_weights, weights.data(), sizeof(float) * inputSize * outputSize,
         cudaMemcpyHostToDevice
@@ -89,12 +89,12 @@ void Layers::Dense::toCuda() {
     ));
 }
 
-void Layers::Dense::setWeights(const float* weights_input) {
+void Dense::setWeights(const float* weights_input) {
     std::copy(weights_input, weights_input + weights.size(), weights.begin());
     toCuda();
 }
 
-void Layers::Dense::setBiases(const float* biases_input) {
+void Dense::setBiases(const float* biases_input) {
     std::copy(biases_input, biases_input + biases.size(), biases.begin());
     toCuda();
 }
\ No newline at end of file
diff --git a/src/layers/input.cu b/src/layers/input.cu
index e8015a2..2e5f157 100644
--- a/src/layers/input.cu
+++ b/src/layers/input.cu
@@ -1,14 +1,14 @@
 #include "cuda_helper.cuh"
 #include "input.cuh"
 
-using namespace CUDANet;
+using namespace CUDANet::Layers;
 
-Layers::Input::Input(int inputSize) : inputSize(inputSize) {
+Input::Input(int inputSize) : inputSize(inputSize) {
     d_output = nullptr;
     CUDA_CHECK(cudaMalloc((void**)&d_output, sizeof(float) * inputSize));
 }
 
-Layers::Input::~Input() {
+Input::~Input() {
     cudaFree(d_output);
 }
 
@@ -19,7 +19,7 @@ Args
     const float* input Host pointer to input data
     float* d_output    Device pointer to input data copied to device
 */
-float* Layers::Input::forward(const float* input) {
+float* Input::forward(const float* input) {
     CUDA_CHECK(cudaMemcpy(
         d_output, input, sizeof(float) * inputSize, cudaMemcpyHostToDevice
     ));