From f17debc244ab8f108dc3d05f5f5b2aba1ac3a98d Mon Sep 17 00:00:00 2001
From: LordMathis
Date: Mon, 22 Apr 2024 20:31:58 +0200
Subject: [PATCH] Implement getOutputSize and getInputSize for seq layers

---
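Notes (for reviewers; not part of the commit message):

getOutputSize() and getInputSize() now return the *flattened* element
count of a layer's output/input (width * height * channels for the 2D
layers), not the spatial width that getOutputSize() returned before, so
callers can size host buffers and copies directly from the getters, as
the updated tests do. A minimal sketch of the intended call pattern
follows. The Dense dimensions and activation are made up for
illustration, it assumes d_input is a device pointer prepared by an
Input layer, and that ActivationType spells its members as listed in
the Dense docs ('RELU', 'SIGMOID', 'SOFTMAX', 'NONE'):

    #include <vector>
    #include <cuda_runtime.h>
    #include "dense.cuh"

    // Illustrative sizes only.
    CUDANet::Layers::Dense dense(
        512, 10, CUDANet::Layers::ActivationType::SOFTMAX
    );

    float* d_out = dense.forward(d_input);

    // getOutputSize() is the flattened output length (10 here), so it
    // can be used directly for allocation and memcpy sizes.
    std::vector<float> out(dense.getOutputSize());
    cudaMemcpy(
        out.data(), d_out,
        sizeof(float) * dense.getOutputSize(),
        cudaMemcpyDeviceToHost
    );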
 include/layers/avg_pooling.cuh  | 15 +++++---
 include/layers/conv2d.cuh       | 15 +++++---
 include/layers/dense.cuh        | 19 ++++++++--
 include/layers/input.cuh        | 23 +++++++++---
 include/layers/layer.cuh        | 18 ++++++++--
 include/layers/max_pooling.cuh  | 15 +++++---
 include/layers/output.cuh       | 62 ++++++++++++++++++++-------------
 src/layers/avg_pooling.cu       |  8 +++++
 src/layers/conv2d.cu            |  8 +++++
 src/layers/dense.cu             |  8 +++++
 src/layers/input.cu             |  9 +++++
 src/layers/max_pooling.cu       |  8 +++++
 src/layers/output.cu            |  9 +++++
 src/model/model.cpp             |  1 -
 test/kernels/test_matmul.cu     |  6 ++--
 test/layers/test_avg_pooling.cu |  4 +--
 test/layers/test_conv2d.cu      | 20 +++++------
 test/layers/test_max_pooling.cu |  4 +--
 18 files changed, 186 insertions(+), 66 deletions(-)

diff --git a/include/layers/avg_pooling.cuh b/include/layers/avg_pooling.cuh
index 8de40ca..11d5431 100644
--- a/include/layers/avg_pooling.cuh
+++ b/include/layers/avg_pooling.cuh
@@ -20,13 +20,18 @@ class AvgPooling2D : public SequentialLayer {
     float* forward(const float* d_input);
 
     /**
-     * @brief Get the output width (/ height) of the layer
+     * @brief Get the flattened output size (width * height * channels)
      *
-     * @return int
+     * @return int output size
      */
-    int getOutputSize() {
-        return outputSize;
-    }
+    int getOutputSize();
+
+    /**
+     * @brief Get the flattened input size (width * height * channels)
+     *
+     * @return int input size
+     */
+    int getInputSize();
 
   private:
     int inputSize;
diff --git a/include/layers/conv2d.cuh b/include/layers/conv2d.cuh
index af9bed9..9c8de89 100644
--- a/include/layers/conv2d.cuh
+++ b/include/layers/conv2d.cuh
@@ -80,13 +80,18 @@ class Conv2d : public WeightedLayer {
     std::vector<float> getBiases();
 
     /**
-     * @brief Get the output width (/ height) of the layer
+     * @brief Get the flattened output size (width * height * filters)
      *
-     * @return int
+     * @return int output size
      */
-    int getOutputSize() {
-        return outputSize;
-    }
+    int getOutputSize();
+
+    /**
+     * @brief Get the flattened input size (width * height * channels)
+     *
+     * @return int input size
+     */
+    int getInputSize();
 
     /**
      * @brief Get the padding size of the layer
diff --git a/include/layers/dense.cuh b/include/layers/dense.cuh
index dd8c988..93356ba 100644
--- a/include/layers/dense.cuh
+++ b/include/layers/dense.cuh
@@ -3,8 +3,8 @@
 
 #include <vector>
 
-#include "layer.cuh"
 #include "activation.cuh"
+#include "layer.cuh"
 
 namespace CUDANet::Layers {
 
@@ -19,7 +19,8 @@ class Dense : public WeightedLayer {
      *
      * @param inputSize Size of the input vector
      * @param outputSize Size of the output vector
-     * @param activationType Activation function type ('RELU', 'SIGMOID', 'SOFTMAX' or 'NONE')
+     * @param activationType Activation function type ('RELU', 'SIGMOID',
+     * 'SOFTMAX' or 'NONE')
      */
     Dense(int inputSize, int outputSize, Layers::ActivationType activationType);
 
@@ -65,6 +66,20 @@
      */
     std::vector<float> getBiases();
 
+    /**
+     * @brief Get the size of the output vector
+     *
+     * @return int output size
+     */
+    int getOutputSize();
+
+    /**
+     * @brief Get the size of the input vector
+     *
+     * @return int input size
+     */
+    int getInputSize();
+
   private:
     unsigned int inputSize;
     unsigned int outputSize;
diff --git a/include/layers/input.cuh b/include/layers/input.cuh
index 144cbe9..52e43bd 100644
--- a/include/layers/input.cuh
+++ b/include/layers/input.cuh
@@ -13,25 +13,40 @@ class Input : public SequentialLayer {
   public:
     /**
      * @brief Create a new Input layer
-     * 
+     *
      * @param inputSize Size of the input vector
      */
     explicit Input(int inputSize);
 
     /**
      * @brief Destroy the Input layer
-     * 
+     *
      */
     ~Input();
 
     /**
-     * @brief Forward pass of the input layer. Just copies the input to the device
-     * 
+     * @brief Forward pass of the input layer. Just copies the input to the
+     * device
+     *
      * @param input Host pointer to the input vector
      * @return Device pointer to the output vector
      */
     float* forward(const float* input);
 
+    /**
+     * @brief Get the output size (same as the input size)
+     *
+     * @return int output size
+     */
+    int getOutputSize();
+
+    /**
+     * @brief Get the input size
+     *
+     * @return int input size
+     */
+    int getInputSize();
+
   private:
     int inputSize;
     float* d_output;
diff --git a/include/layers/layer.cuh b/include/layers/layer.cuh
index c89fe4f..495ab95 100644
--- a/include/layers/layer.cuh
+++ b/include/layers/layer.cuh
@@ -4,8 +4,8 @@
 
 #include 
 
-#define CUDANET_SAME_PADDING(inputSize, kernelSize, stride) ((stride - 1) * inputSize - stride + kernelSize) / 2;
-
+#define CUDANET_SAME_PADDING(inputSize, kernelSize, stride) \
+    ((stride - 1) * inputSize - stride + kernelSize) / 2;
 namespace CUDANet::Layers {
 
 /**
@@ -28,6 +28,20 @@ class SequentialLayer {
      * @return float* Device pointer to the output
      */
     virtual float* forward(const float* input) = 0;
+
+    /**
+     * @brief Get the flattened output size of the layer
+     *
+     * @return int output size
+     */
+    virtual int getOutputSize() = 0;
+
+    /**
+     * @brief Get the flattened input size of the layer
+     *
+     * @return int input size
+     */
+    virtual int getInputSize() = 0;
 };
 
 /**
diff --git a/include/layers/max_pooling.cuh b/include/layers/max_pooling.cuh
index cb5d06e..6157cae 100644
--- a/include/layers/max_pooling.cuh
+++ b/include/layers/max_pooling.cuh
@@ -20,13 +20,18 @@ class MaxPooling2D : public SequentialLayer {
     float* forward(const float* d_input);
 
     /**
-     * @brief Get the output width (/ height) of the layer
+     * @brief Get the flattened output size (width * height * channels)
      *
-     * @return int
+     * @return int output size
      */
-    int getOutputSize() {
-        return outputSize;
-    }
+    int getOutputSize();
+
+    /**
+     * @brief Get the flattened input size (width * height * channels)
+     *
+     * @return int input size
+     */
+    int getInputSize();
 
   private:
     int inputSize;
Just copies the input from + * device to host + * + * @param input Device pointer to the input vector + * @return Host pointer to the output vector + */ + float* forward(const float* input); - private: - int inputSize; - float* h_output; + /** + * @brief Get output size + * + * @return int output size + */ + int getOutputSize(); + + /** + * @brief Get input size + * + * @return int input size + */ + int getInputSize(); + + private: + int inputSize; + float* h_output; }; - -} // namespace CUDANet::Layers +} // namespace CUDANet::Layers #endif // CUDANET_OUTPUT_LAYER_H \ No newline at end of file diff --git a/src/layers/avg_pooling.cu b/src/layers/avg_pooling.cu index 951dee5..23ec5d5 100644 --- a/src/layers/avg_pooling.cu +++ b/src/layers/avg_pooling.cu @@ -49,4 +49,12 @@ float* AvgPooling2D::forward(const float* d_input) { CUDA_CHECK(cudaDeviceSynchronize()); return d_output; +} + +int AvgPooling2D::getOutputSize() { + return outputSize * outputSize * nChannels; +} + +int AvgPooling2D::getInputSize() { + return inputSize * inputSize * nChannels; } \ No newline at end of file diff --git a/src/layers/conv2d.cu b/src/layers/conv2d.cu index 3f1f829..82a3aa5 100644 --- a/src/layers/conv2d.cu +++ b/src/layers/conv2d.cu @@ -130,3 +130,11 @@ float* Conv2d::forward(const float* d_input) { return d_output; } + +int Conv2d::getOutputSize() { + return outputSize * outputSize * numFilters; +} + +int Conv2d::getInputSize() { + return inputSize * inputSize * inputChannels; +} \ No newline at end of file diff --git a/src/layers/dense.cu b/src/layers/dense.cu index 47310bb..6a13d94 100644 --- a/src/layers/dense.cu +++ b/src/layers/dense.cu @@ -108,4 +108,12 @@ void Dense::setBiases(const float* biases_input) { std::vector Dense::getBiases() { return biases; +} + +int Dense::getOutputSize() { + return outputSize; +} + +int Dense::getInputSize() { + return inputSize; } \ No newline at end of file diff --git a/src/layers/input.cu b/src/layers/input.cu index 59ec381..39531ad 100644 --- a/src/layers/input.cu +++ b/src/layers/input.cu @@ -20,3 +20,12 @@ float* Input::forward(const float* input) { return d_output; } + +int Input::getOutputSize() { + return inputSize; +} + + +int Input::getInputSize() { + return inputSize; +} \ No newline at end of file diff --git a/src/layers/max_pooling.cu b/src/layers/max_pooling.cu index 8f8e3d4..d76ae77 100644 --- a/src/layers/max_pooling.cu +++ b/src/layers/max_pooling.cu @@ -52,4 +52,12 @@ float* MaxPooling2D::forward(const float* d_input) { CUDA_CHECK(cudaDeviceSynchronize()); return d_output; +} + +int MaxPooling2D::getOutputSize() { + return outputSize * outputSize * nChannels; +} + +int MaxPooling2D::getInputSize() { + return inputSize * inputSize * nChannels; } \ No newline at end of file diff --git a/src/layers/output.cu b/src/layers/output.cu index a37afa5..5db0828 100644 --- a/src/layers/output.cu +++ b/src/layers/output.cu @@ -20,4 +20,13 @@ float* Output::forward(const float* input) { CUDA_CHECK(cudaDeviceSynchronize()); return h_output; +} + +int Output::getOutputSize() { + return inputSize; +} + + +int Output::getInputSize() { + return inputSize; } \ No newline at end of file diff --git a/src/model/model.cpp b/src/model/model.cpp index 408c1e9..8b138c9 100644 --- a/src/model/model.cpp +++ b/src/model/model.cpp @@ -43,7 +43,6 @@ float* Model::predict(const float* input) { float* d_input = inputLayer->forward(input); for (auto& layer : layers) { - std::cout << layer.first << std::endl; d_input = layer.second->forward(d_input); } diff --git 
diff --git a/test/kernels/test_matmul.cu b/test/kernels/test_matmul.cu
index abef1b5..3941be7 100644
--- a/test/kernels/test_matmul.cu
+++ b/test/kernels/test_matmul.cu
@@ -187,6 +187,8 @@ TEST(MatMulTest, SumReduceTest) {
     cudaMemcpy(d_input, input.data(), sizeof(float) * n, cudaMemcpyHostToDevice);
     EXPECT_EQ(cudaStatus, cudaSuccess);
 
+    CUDANet::Utils::clear(d_sum, n);
+
     CUDANet::Kernels::sum_reduce<<>>(
         d_input, d_sum, n
     );
@@ -208,7 +210,5 @@ TEST(MatMulTest, SumReduceTest) {
     EXPECT_FLOAT_EQ(expected, sum[0]);
 
     cudaFree(d_input);
-    cudaFree(d_sum);
-
-
+    cudaFree(d_sum);
 }
\ No newline at end of file
diff --git a/test/layers/test_avg_pooling.cu b/test/layers/test_avg_pooling.cu
index 30ad3ec..29cd5a1 100644
--- a/test/layers/test_avg_pooling.cu
+++ b/test/layers/test_avg_pooling.cu
@@ -51,10 +51,10 @@ TEST(AvgPoolingLayerTest, AvgPoolForwardTest) {
 
     int outputSize = avgPoolingLayer.getOutputSize();
 
-    std::vector<float> output(outputSize * outputSize * nChannels);
+    std::vector<float> output(outputSize);
     cudaStatus = cudaMemcpy(
         output.data(), d_output,
-        sizeof(float) * outputSize * outputSize * nChannels,
+        sizeof(float) * outputSize,
         cudaMemcpyDeviceToHost
     );
     EXPECT_EQ(cudaStatus, cudaSuccess);
diff --git a/test/layers/test_conv2d.cu b/test/layers/test_conv2d.cu
index d054e95..9ab986c 100644
--- a/test/layers/test_conv2d.cu
+++ b/test/layers/test_conv2d.cu
@@ -82,14 +82,15 @@ TEST_F(Conv2dTest, SimpleTest) {
         activationType, input, kernels.data(), d_input
     );
 
-    int outputSize = (inputSize - kernelSize) / stride + 1;
+    int outputWidth = (inputSize - kernelSize) / stride + 1;
+    int outputSize = outputWidth * outputWidth * numFilters;
     EXPECT_EQ(outputSize, conv2d.getOutputSize());
 
     d_output = conv2d.forward(d_input);
 
     std::vector<float> expected = {44.0f,  54.0f,  64.0f,  84.0f,  94.0f,
                                    104.0f, 124.0f, 134.0f, 144.0f};
-    std::vector<float> output(outputSize * outputSize * numFilters);
+    std::vector<float> output(outputSize);
     cudaStatus = cudaMemcpy(
         output.data(), d_output, sizeof(float) * output.size(),
         cudaMemcpyDeviceToHost
@@ -172,18 +173,16 @@ TEST_F(Conv2dTest, PaddedTest) {
         activationType, input, kernels.data(), d_input
     );
 
-    EXPECT_EQ(inputSize, conv2d.getOutputSize());
+    EXPECT_EQ(inputSize * inputSize * numFilters, conv2d.getOutputSize());
 
     d_output = conv2d.forward(d_input);
 
     std::vector<float> output(
-        conv2d.getOutputSize() * conv2d.getOutputSize() * numFilters
+        conv2d.getOutputSize()
     );
     cudaMemcpy(
         output.data(), d_output,
-        sizeof(float) * conv2d.getOutputSize() * conv2d.getOutputSize() *
-            numFilters,
-        cudaMemcpyDeviceToHost
+        sizeof(float) * conv2d.getOutputSize(), cudaMemcpyDeviceToHost
     );
 
     // Generated by tools/generate_conv2d_test.py
@@ -259,17 +258,16 @@ TEST_F(Conv2dTest, StridedPaddedConvolution) {
         activationType, input, kernels.data(), d_input
     );
 
-    EXPECT_EQ(inputSize, conv2d.getOutputSize());
+    EXPECT_EQ(inputSize * inputSize * numFilters, conv2d.getOutputSize());
 
     d_output = conv2d.forward(d_input);
 
     std::vector<float> output(
-        conv2d.getOutputSize() * conv2d.getOutputSize() * numFilters
+        conv2d.getOutputSize()
     );
     cudaMemcpy(
         output.data(), d_output,
-        sizeof(float) * conv2d.getOutputSize() * conv2d.getOutputSize() *
-            numFilters,
+        sizeof(float) * conv2d.getOutputSize(),
         cudaMemcpyDeviceToHost
     );
diff --git a/test/layers/test_max_pooling.cu b/test/layers/test_max_pooling.cu
index 8cd6211..09c5214 100644
--- a/test/layers/test_max_pooling.cu
+++ b/test/layers/test_max_pooling.cu
@@ -51,10 +51,10 @@ TEST(MaxPoolingLayerTest, MaxPoolForwardTest) {
 
     int outputSize = maxPoolingLayer.getOutputSize();
 
-    std::vector<float> output(outputSize * outputSize * nChannels);
+    std::vector<float> output(outputSize);
     cudaStatus = cudaMemcpy(
         output.data(), d_output,
-        sizeof(float) * outputSize * outputSize * nChannels,
+        sizeof(float) * outputSize,
         cudaMemcpyDeviceToHost
     );
    EXPECT_EQ(cudaStatus, cudaSuccess);
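
P.S. A quick sanity check of the new flattened semantics against the
unpadded conv case; the numbers below are illustrative, since the
SimpleTest fixture values are not visible in this diff:

    // Unpadded convolution, as in SimpleTest:
    int inputSize = 4, kernelSize = 2, stride = 1, numFilters = 1;  // illustrative
    int outputWidth = (inputSize - kernelSize) / stride + 1;        // 3
    int outputSize = outputWidth * outputWidth * numFilters;        // 9

With these numbers getOutputSize() would report 9 elements, which is
consistent with the nine values in SimpleTest's `expected` vector.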