From aac0c3a82668d5e3333677a2f26beee71403c7f1 Mon Sep 17 00:00:00 2001
From: LordMathis
Date: Sun, 17 Mar 2024 21:38:29 +0100
Subject: [PATCH] Implement concat layer

---
 CMakeLists.txt             |  1 +
 README.md                  |  2 +-
 include/layers/concat.cuh  | 45 ++++++++++++++++++++++++++++++++++++++
 include/layers/ilayer.cuh  | 12 ----------
 include/layers/input.cuh   | 10 +--------
 src/layers/concat.cu       | 32 +++++++++++++++++++++++++++
 src/layers/input.cu        |  8 -------
 test/CMakeLists.txt        |  7 +++---
 test/layers/test_concat.cu | 37 +++++++++++++++++++++++++++++++
 9 files changed, 121 insertions(+), 33 deletions(-)
 create mode 100644 include/layers/concat.cuh
 create mode 100644 src/layers/concat.cu
 create mode 100644 test/layers/test_concat.cu

diff --git a/CMakeLists.txt b/CMakeLists.txt
index f5ea2ad..5d95a76 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -14,6 +14,7 @@ set(LIBRARY_SOURCES
     src/kernels/matmul.cu
     src/layers/dense.cu
     src/layers/conv2d.cu
+    src/layers/concat.cu
     src/layers/input.cu
     src/layers/activation.cu
 )
diff --git a/README.md b/README.md
index 158f98c..8fc9e86 100644
--- a/README.md
+++ b/README.md
@@ -11,7 +11,7 @@ Convolutional Neural Network inference library running on CUDA.
 - [x] Conv2d layer
 - [ ] Max pooling
 - [ ] Average pooling
-- [ ] Concat layer
+- [x] Concat layer
 - [x] Sigmoid activation
 - [x] ReLU activation
 - [x] Softmax activation
diff --git a/include/layers/concat.cuh b/include/layers/concat.cuh
new file mode 100644
index 0000000..4b5513f
--- /dev/null
+++ b/include/layers/concat.cuh
@@ -0,0 +1,45 @@
+#ifndef CUDANET_CONCAT_LAYER_H
+#define CUDANET_CONCAT_LAYER_H
+
+#include
+
+namespace CUDANet::Layers {
+
+/**
+ * @brief Concatenate layers
+ *
+ */
+class Concat {
+  public:
+    /**
+     * @brief Create a new Concat layer
+     *
+     * @param inputASize,inputBSize Sizes of the two inputs to concatenate
+     */
+    Concat(const unsigned int inputASize, const unsigned int inputBSize);
+
+    /**
+     * @brief Destroy the Concat layer
+     *
+     */
+    ~Concat();
+
+    /**
+     * @brief Forward pass of the concat layer
+     *
+     * @param d_input_A Device pointer to the first input
+     * @param d_input_B Device pointer to the second input
+     * @return Device pointer to the output
+     */
+    float* forward(const float* d_input_A, const float* d_input_B);
+
+  private:
+    unsigned int inputASize;
+    unsigned int inputBSize;
+
+    float* d_output;
+};
+
+}  // namespace CUDANet::Layers
+
+#endif  // CUDANET_CONCAT_LAYER_H
diff --git a/include/layers/ilayer.cuh b/include/layers/ilayer.cuh
index f1ab5fd..3ab2109 100644
--- a/include/layers/ilayer.cuh
+++ b/include/layers/ilayer.cuh
@@ -64,18 +64,6 @@ class ILayer {
      * @brief Copy the weights and biases to the device
      */
     virtual void toCuda() = 0;
-
-    int inputSize;
-    int outputSize;
-
-    float* d_output;
-
-    float* d_weights;
-    float* d_biases;
-
-    std::vector<float> weights;
-    std::vector<float> biases;
-
 };
 
 }  // namespace CUDANet::Layers
diff --git a/include/layers/input.cuh b/include/layers/input.cuh
index ee6ecd7..fda2722 100644
--- a/include/layers/input.cuh
+++ b/include/layers/input.cuh
@@ -9,7 +9,7 @@ namespace CUDANet::Layers {
  * @brief Input layer, just copies the input to the device
  *
  */
-class Input : public ILayer {
+class Input {
   public:
     /**
      * @brief Create a new Input layer
@@ -32,15 +32,7 @@
      */
     float* forward(const float* input);
 
-    void setWeights(const float* weights);
-    void setBiases(const float* biases);
-
   private:
-    void initializeWeights();
-    void initializeBiases();
-
-    void toCuda();
-
     int inputSize;
     float* d_output;
 };
diff --git a/src/layers/concat.cu b/src/layers/concat.cu
new file mode 100644
index 0000000..127017d
--- /dev/null
+++ b/src/layers/concat.cu
@@ -0,0 +1,32 @@
+#include "concat.cuh"
+#include "cuda_helper.cuh"
+
+using namespace CUDANet;
+
+
+Layers::Concat::Concat(const unsigned int inputASize, const unsigned int inputBSize)
+    : inputASize(inputASize), inputBSize(inputBSize) {
+
+    d_output = nullptr;
+    CUDA_CHECK(cudaMalloc(
+        (void**)&d_output, sizeof(float) * (inputASize + inputBSize)
+    ));
+
+}
+
+Layers::Concat::~Concat() {
+    cudaFree(d_output);
+}
+
+
+float* Layers::Concat::forward(const float* d_input_A, const float* d_input_B) {
+    CUDA_CHECK(cudaMemcpy(
+        d_output, d_input_A, sizeof(float) * inputASize, cudaMemcpyDeviceToDevice
+    ));
+    CUDA_CHECK(cudaMemcpy(
+        d_output + inputASize, d_input_B,
+        sizeof(float) * inputBSize, cudaMemcpyDeviceToDevice
+    ));
+
+    return d_output;
+}
diff --git a/src/layers/input.cu b/src/layers/input.cu
index d50dde3..e8015a2 100644
--- a/src/layers/input.cu
+++ b/src/layers/input.cu
@@ -26,11 +26,3 @@ float* Layers::Input::forward(const float* input) {
 
     return d_output;
 }
-
-void Layers::Input::setWeights(const float* weights) {}
-void Layers::Input::setBiases(const float* biases) {}
-
-void Layers::Input::initializeWeights() {}
-void Layers::Input::initializeBiases() {}
-
-void Layers::Input::toCuda() {}
\ No newline at end of file
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 4db8714..666b5f2 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -3,10 +3,11 @@ include_directories(${GTEST_INCLUDE_DIRS})
 
 add_executable(test_main
     EXCLUDE_FROM_ALL
-    layers/test_dense.cu
-    layers/test_conv2d.cu
-    layers/test_input.cu
     layers/test_activation.cu
+    layers/test_concat.cu
+    layers/test_conv2d.cu
+    layers/test_dense.cu
+    layers/test_input.cu
     kernels/test_activation_functions.cu
     kernels/test_padding.cu
     kernels/test_matmul.cu
diff --git a/test/layers/test_concat.cu b/test/layers/test_concat.cu
new file mode 100644
index 0000000..a5adc97
--- /dev/null
+++ b/test/layers/test_concat.cu
@@ -0,0 +1,37 @@
+#include "concat.cuh"
+#include <cuda_runtime.h>
+#include <gtest/gtest.h>
+#include <vector>
+
+TEST(ConcatLayerTest, Init) {
+    std::vector<float> inputA = {0.573f, 0.619f, 0.732f, 0.055f, 0.243f};
+    std::vector<float> inputB = {0.123f, 0.321f, 0.456f, 0.789f, 0.654f, 0.123f};
+
+    CUDANet::Layers::Concat concat(5, 6);
+
+    float* d_inputA;
+    float* d_inputB;
+    cudaMalloc((void**)&d_inputA, sizeof(float) * 5);
+    cudaMalloc((void**)&d_inputB, sizeof(float) * 6);
+    cudaMemcpy(
+        d_inputA, inputA.data(), sizeof(float) * 5, cudaMemcpyHostToDevice
+    );
+    cudaMemcpy(
+        d_inputB, inputB.data(), sizeof(float) * 6, cudaMemcpyHostToDevice
+    );
+
+    float* d_output = concat.forward(d_inputA, d_inputB);
+
+    std::vector<float> output(11);
+    cudaMemcpy(
+        output.data(), d_output, sizeof(float) * 11, cudaMemcpyDeviceToHost
+    );
+
+    for (int i = 0; i < 5; ++i) {
+        EXPECT_EQ(output[i], inputA[i]);
+    }
+    for (int i = 0; i < 6; ++i) {
+        EXPECT_EQ(output[i + 5], inputB[i]);
+    }
+    cudaFree(d_output);
+}
\ No newline at end of file
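
A minimal usage sketch of the new layer follows (not part of the patch): it assumes the library built by the CMakeLists.txt change above is linked and that include/layers is on the include path; the file name, buffer sizes, and host values are made up for illustration.

// concat_usage.cu -- hypothetical example built against this patch
#include "concat.cuh"

#include <cuda_runtime.h>
#include <iostream>
#include <vector>

int main() {
    // Host data standing in for the outputs of two upstream layers
    std::vector<float> hostA = {1.0f, 2.0f, 3.0f};
    std::vector<float> hostB = {4.0f, 5.0f};

    // Copy both inputs to the device
    float* d_a = nullptr;
    float* d_b = nullptr;
    cudaMalloc((void**)&d_a, sizeof(float) * hostA.size());
    cudaMalloc((void**)&d_b, sizeof(float) * hostB.size());
    cudaMemcpy(d_a, hostA.data(), sizeof(float) * hostA.size(), cudaMemcpyHostToDevice);
    cudaMemcpy(d_b, hostB.data(), sizeof(float) * hostB.size(), cudaMemcpyHostToDevice);

    // Concat allocates a device output buffer of inputASize + inputBSize floats
    CUDANet::Layers::Concat concat(3, 2);
    float* d_out = concat.forward(d_a, d_b);

    // Read the concatenated result back to the host
    std::vector<float> out(5);
    cudaMemcpy(out.data(), d_out, sizeof(float) * out.size(), cudaMemcpyDeviceToHost);

    for (float v : out) {
        std::cout << v << " ";  // prints: 1 2 3 4 5
    }
    std::cout << "\n";

    // Free only the inputs; Concat owns d_out and releases it in its destructor
    cudaFree(d_a);
    cudaFree(d_b);
    return 0;
}

Note that forward returns the layer's own d_output buffer, so callers normally should not cudaFree it themselves; the destructor already does.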