From a0fc1b00ae3ee79de26a5387fb15612e72bea9e6 Mon Sep 17 00:00:00 2001
From: LordMathis
Date: Tue, 19 Mar 2024 22:04:58 +0100
Subject: [PATCH] Implement max pooling layer

---
 include/kernels/pooling.cuh         | 32 +++++++++
 include/layers/conv2d.cuh           | 19 +++---
 include/layers/max_pooling.cuh      | 42 ++++++++++++
 src/kernels/activation_functions.cu | 12 ++--
 src/kernels/convolution.cu          |  4 +-
 src/kernels/matmul.cu               |  6 +-
 src/kernels/pooling.cu              | 99 +++++++++++++++++++++++++++++
 src/layers/conv2d.cu                |  2 +-
 src/layers/max_pooling.cu           | 59 ++++++++++++++++++
 9 files changed, 254 insertions(+), 21 deletions(-)
 create mode 100644 include/kernels/pooling.cuh
 create mode 100644 include/layers/max_pooling.cuh
 create mode 100644 src/kernels/pooling.cu
 create mode 100644 src/layers/max_pooling.cu

diff --git a/include/kernels/pooling.cuh b/include/kernels/pooling.cuh
new file mode 100644
index 0000000..a4722ea
--- /dev/null
+++ b/include/kernels/pooling.cuh
@@ -0,0 +1,32 @@
+#ifndef CUDANET_POOLING_H
+#define CUDANET_POOLING_H
+
+#include <cuda_runtime.h>
+
+namespace CUDANet::Kernels {
+
+__global__ void max_pooling(
+    const float* __restrict__ d_input,
+    float* __restrict__ d_output,
+    const int inputSize,
+    const int outputSize,
+    const int nChannels,
+    const int poolingSize,
+    const int stride,
+    const int paddingSize
+);
+
+__global__ void avg_pooling(
+    const float* __restrict__ d_input,
+    float* __restrict__ d_output,
+    const int inputSize,
+    const int outputSize,
+    const int nChannels,
+    const int poolingSize,
+    const int stride,
+    const int paddingSize
+);
+
+} // namespace CUDANet::Kernels
+
+#endif // CUDANET_POOLING_H
\ No newline at end of file
diff --git a/include/layers/conv2d.cuh b/include/layers/conv2d.cuh
index 40ad3c5..5cd6ecf 100644
--- a/include/layers/conv2d.cuh
+++ b/include/layers/conv2d.cuh
@@ -25,16 +25,17 @@ class Conv2d : public WeightedLayer {
      * @param stride Convolution stride
      * @param numFilters Number of output filters
      * @param padding Padding type ('SAME' or 'VALID')
-     * @param activationType Activation function type ('RELU', 'SIGMOID', 'SOFTMAX' or 'NONE')
+     * @param activationType Activation function type ('RELU', 'SIGMOID',
+     * 'SOFTMAX' or 'NONE')
      */
     Conv2d(
-        int                     inputSize,
-        int                     inputChannels,
-        int                     kernelSize,
-        int                     stride,
-        int                     numFilters,
-        Layers::Padding         padding,
-        Layers::ActivationType  activationType
+        int            inputSize,
+        int            inputChannels,
+        int            kernelSize,
+        int            stride,
+        int            numFilters,
+        Padding        padding,
+        ActivationType activationType
     );
 
     /**
@@ -107,7 +108,7 @@ class Conv2d : public WeightedLayer {
     float* d_biases;
 
     // Kernels
-    Layers::Activation activation;
+    Activation activation;
 
     /**
      * @brief Initialize weights of the convolutional layer with zeros
diff --git a/include/layers/max_pooling.cuh b/include/layers/max_pooling.cuh
new file mode 100644
index 0000000..542777e
--- /dev/null
+++ b/include/layers/max_pooling.cuh
@@ -0,0 +1,42 @@
+#ifndef CUDANET_MAX_POOLING_H
+#define CUDANET_MAX_POOLING_H
+
+#include <cuda_runtime.h>
+
+#include "layer.cuh"
+#include "activation.cuh"
+
+namespace CUDANet::Layers {
+
+class MaxPooling2D : public SequentialLayer {
+  public:
+    MaxPooling2D(
+        int inputSize,
+        int nChannels,
+        int poolingSize,
+        int stride,
+        Padding padding,
+        ActivationType activationType
+    );
+    ~MaxPooling2D();
+
+    float* forward(const float* d_input);
+
+  private:
+    int inputSize;
+    int nChannels;
+    int poolingSize;
+    int stride;
+    int paddingSize;
+
+    int outputSize;
+    int gridSize;
+
+    float* d_output;
+
+    Activation activation;
+};
+
+} // namespace CUDANet::Layers
+
+#endif // CUDANET_MAX_POOLING_H
\ No newline at end of file
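The header above is the whole public surface of the new layer. As a quick illustration of how it is meant to be driven from host code, here is a minimal sketch; it is not part of the patch, and it assumes the enum spellings visible in the headers (Padding::VALID, ActivationType::NONE) plus the standard CUDA runtime API for staging the input:

// Illustrative only: pool a 4x4 single-channel input with a 2x2 window,
// stride 2 and VALID padding, giving a 2x2 output.
#include <cuda_runtime.h>
#include <vector>

#include "max_pooling.cuh"

int main() {
    std::vector<float> h_input(4 * 4, 1.0f);

    float* d_input = nullptr;
    cudaMalloc((void**)&d_input, sizeof(float) * h_input.size());
    cudaMemcpy(d_input, h_input.data(), sizeof(float) * h_input.size(),
               cudaMemcpyHostToDevice);

    CUDANet::Layers::MaxPooling2D maxPool(
        4, 1, 2, 2, CUDANet::Layers::Padding::VALID,
        CUDANet::Layers::ActivationType::NONE
    );

    // forward() returns a device pointer that the layer owns
    float* d_output = maxPool.forward(d_input);

    std::vector<float> h_output(2 * 2);
    cudaMemcpy(h_output.data(), d_output, sizeof(float) * h_output.size(),
               cudaMemcpyDeviceToHost);

    cudaFree(d_input);
    return 0;
}
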
diff --git a/src/kernels/activation_functions.cu b/src/kernels/activation_functions.cu
index 3315d3c..5e8e960 100644
--- a/src/kernels/activation_functions.cu
+++ b/src/kernels/activation_functions.cu
@@ -3,9 +3,9 @@
 #include "activation_functions.cuh"
 #include "cuda_helper.cuh"
 
-using namespace CUDANet::Kernels;
+using namespace CUDANet;
 
-__global__ void sigmoid(
+__global__ void Kernels::sigmoid(
     const float* __restrict__ src,
     float* __restrict__ dst,
     const unsigned int len
@@ -18,7 +18,7 @@ __global__ void sigmoid(
     }
 }
 
-__global__ void relu(
+__global__ void Kernels::relu(
     const float* __restrict__ src,
     float* __restrict__ dst,
     const unsigned int len
@@ -31,7 +31,7 @@ __global__ void relu(
     }
 }
 
-__global__ void softmax_exp(
+__global__ void Kernels::softmax_exp(
     const float* __restrict__ src,
     float* __restrict__ dst,
     const unsigned int len
@@ -44,7 +44,7 @@ __global__ void softmax_exp(
     }
 }
 
-__global__ void softmax_sum(
+__global__ void Kernels::softmax_sum(
     const float* __restrict__ d_vector,
     float* __restrict__ d_output,
    const unsigned int w
@@ -66,7 +66,7 @@ __global__ void softmax_sum(
     }
 }
 
-__global__ void softmax_div(
+__global__ void Kernels::softmax_div(
     const float* __restrict__ src,
     float* __restrict__ dst,
     const float* __restrict__ sum,
diff --git a/src/kernels/convolution.cu b/src/kernels/convolution.cu
index 28d5ca2..9f1efc2 100644
--- a/src/kernels/convolution.cu
+++ b/src/kernels/convolution.cu
@@ -2,9 +2,9 @@
 
 #include "convolution.cuh"
 
-using namespace CUDANet::Kernels;
+using namespace CUDANet;
 
-__global__ void convolution(
+__global__ void Kernels::convolution(
     const float* __restrict__ d_input,
     const float* __restrict__ d_kernel,
     float* __restrict__ d_output,
diff --git a/src/kernels/matmul.cu b/src/kernels/matmul.cu
index 3bb2e15..24ac1b4 100644
--- a/src/kernels/matmul.cu
+++ b/src/kernels/matmul.cu
@@ -1,9 +1,9 @@
 #include "cuda_helper.cuh"
 #include "matmul.cuh"
 
-using namespace CUDANet::Kernels;
+using namespace CUDANet;
 
-__global__ void mat_vec_mul(
+__global__ void Kernels::mat_vec_mul(
     const float* __restrict__ d_matrix,
     const float* __restrict__ d_vector,
     float* __restrict__ d_output,
@@ -37,7 +37,7 @@ __global__ void mat_vec_mul(
     d_output[tid] = temp;
 }
 
-__global__ void vec_vec_add(
+__global__ void Kernels::vec_vec_add(
     const float* __restrict__ d_vector1,
     const float* __restrict__ d_vector2,
     float* __restrict__ d_output,
diff --git a/src/kernels/pooling.cu b/src/kernels/pooling.cu
new file mode 100644
index 0000000..92df3a9
--- /dev/null
+++ b/src/kernels/pooling.cu
@@ -0,0 +1,99 @@
+#include <cfloat>
+
+#include "pooling.cuh"
+
+#include "cuda_helper.cuh"
+
+using namespace CUDANet;
+
+__global__ void Kernels::max_pooling(
+    const float* __restrict__ d_input,
+    float* __restrict__ d_output,
+    const int inputSize,
+    const int outputSize,
+    const int nChannels,
+    const int poolingSize,
+    const int stride,
+    const int paddingSize
+) {
+    int tid = blockDim.x * blockIdx.x + threadIdx.x;
+    // One thread per output element
+    if (tid >= outputSize * outputSize * nChannels) {
+        return;
+    }
+
+    // Decompose tid into channel, output row and output column
+    int c = tid / (outputSize * outputSize);
+    int i = tid % (outputSize * outputSize) / outputSize;
+    int j = tid % outputSize;
+
+    // Start below any representable input so negative values pool correctly
+    float max = -FLT_MAX;
+
+    for (int k = 0; k < poolingSize; k++) {
+        for (int l = 0; l < poolingSize; l++) {
+            // Skip window positions that fall into the padding
+            if (i * stride + k < paddingSize ||
+                i * stride + k >= (inputSize + paddingSize) ||
+                j * stride + l < paddingSize ||
+                j * stride + l >= (inputSize + paddingSize)) {
+                continue;
+            }
+
+            int inputIndex = c * inputSize * inputSize +
+                             (i * stride + k - paddingSize) * inputSize +
+                             (j * stride + l - paddingSize);
+
+            if (d_input[inputIndex] > max) {
+                max = d_input[inputIndex];
+            }
+        }
+    }
+
+    d_output[tid] = max;
+}
+
+__global__ void Kernels::avg_pooling(
+    const float* __restrict__ d_input,
+    float* __restrict__ d_output,
+    const int inputSize,
+    const int outputSize,
+    const int nChannels,
+    const int poolingSize,
+    const int stride,
+    const int paddingSize
+) {
+    int tid = blockDim.x * blockIdx.x + threadIdx.x;
+    // One thread per output element
+    if (tid >= outputSize * outputSize * nChannels) {
+        return;
+    }
+
+    // Decompose tid into channel, output row and output column
+    int c = tid / (outputSize * outputSize);
+    int i = tid % (outputSize * outputSize) / outputSize;
+    int j = tid % outputSize;
+
+    float sum = 0.0f;
+
+    for (int k = 0; k < poolingSize; k++) {
+        for (int l = 0; l < poolingSize; l++) {
+            // Skip window positions that fall into the padding
+            if (i * stride + k < paddingSize ||
+                i * stride + k >= (inputSize + paddingSize) ||
+                j * stride + l < paddingSize ||
+                j * stride + l >= (inputSize + paddingSize)) {
+                continue;
+            }
+
+            int inputIndex = c * inputSize * inputSize +
+                             (i * stride + k - paddingSize) * inputSize +
+                             (j * stride + l - paddingSize);
+
+            sum += d_input[inputIndex];
+        }
+    }
+
+    // Padded positions count as zero, i.e. the mean is over the full window
+    d_output[tid] = sum / (poolingSize * poolingSize);
+}
\ No newline at end of file
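The index arithmetic in these kernels is the part most worth testing. A host-side reference that mirrors the same window and padding logic makes that easy; the sketch below is not part of the patch (the function name maxPoolReference is made up here), handles a single channel, and derives the output dimension from the usual (input + 2 * padding - window) / stride + 1 relation:

// CPU reference for one channel, mirroring Kernels::max_pooling:
// same window walk, same skip-the-padding bounds check.
#include <algorithm>
#include <cfloat>
#include <vector>

std::vector<float> maxPoolReference(const std::vector<float>& input,
                                    int inputSize, int poolingSize,
                                    int stride, int paddingSize) {
    int outputSize = (inputSize + 2 * paddingSize - poolingSize) / stride + 1;
    std::vector<float> output(outputSize * outputSize, -FLT_MAX);

    for (int i = 0; i < outputSize; i++) {
        for (int j = 0; j < outputSize; j++) {
            for (int k = 0; k < poolingSize; k++) {
                for (int l = 0; l < poolingSize; l++) {
                    int row = i * stride + k - paddingSize;
                    int col = j * stride + l - paddingSize;
                    // Window positions inside the zero padding are skipped,
                    // matching the kernel's bounds check
                    if (row < 0 || row >= inputSize ||
                        col < 0 || col >= inputSize) {
                        continue;
                    }
                    output[i * outputSize + j] = std::max(
                        output[i * outputSize + j],
                        input[row * inputSize + col]);
                }
            }
        }
    }
    return output;
}

Comparing this against the kernel's output for a few random inputs and window/stride combinations catches off-by-one regressions quickly.
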
diff --git a/src/layers/conv2d.cu b/src/layers/conv2d.cu
index 79e3c88..3b046ec 100644
--- a/src/layers/conv2d.cu
+++ b/src/layers/conv2d.cu
@@ -39,7 +39,7 @@ Conv2d::Conv2d(
             break;
     }
 
-    activation = Layers::Activation(
+    activation = Activation(
         activationType, outputSize * outputSize * numFilters
     );
 
diff --git a/src/layers/max_pooling.cu b/src/layers/max_pooling.cu
new file mode 100644
index 0000000..e485b7c
--- /dev/null
+++ b/src/layers/max_pooling.cu
@@ -0,0 +1,59 @@
+#include "max_pooling.cuh"
+#include "cuda_helper.cuh"
+#include "pooling.cuh"
+
+using namespace CUDANet::Layers;
+
+
+MaxPooling2D::MaxPooling2D(
+    int inputSize,
+    int nChannels,
+    int poolingSize,
+    int stride,
+    Padding padding,
+    ActivationType activationType
+)
+    : inputSize(inputSize), nChannels(nChannels), poolingSize(poolingSize), stride(stride) {
+
+    switch (padding) {
+        case SAME:
+            outputSize  = inputSize;
+            paddingSize = ((stride - 1) * inputSize - stride + poolingSize) / 2;
+            break;
+
+        case VALID:
+            paddingSize = 0;
+            outputSize  = (inputSize - poolingSize) / stride + 1;
+            break;
+
+        default:
+            break;
+    }
+
+    activation = Activation(
+        activationType, outputSize * outputSize * nChannels
+    );
+
+    d_output = nullptr;
+    CUDA_CHECK(cudaMalloc(
+        (void**)&d_output, sizeof(float) * outputSize * outputSize * nChannels
+    ));
+
+    gridSize =
+        (outputSize * outputSize * nChannels + BLOCK_SIZE - 1) / BLOCK_SIZE;
+}
+
+
+MaxPooling2D::~MaxPooling2D() {
+    cudaFree(d_output);
+}
+
+
+float* MaxPooling2D::forward(const float* d_input) {
+    Kernels::max_pooling<<<gridSize, BLOCK_SIZE>>>(
+        d_input, d_output, inputSize, outputSize, nChannels, poolingSize,
+        stride, paddingSize
+    );
+
+    return d_output;
+}
\ No newline at end of file
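As a quick check of the constructor's arithmetic: with inputSize = 6, poolingSize = 2, stride = 2 and VALID padding,

    paddingSize = 0
    outputSize  = (6 - 2) / 2 + 1 = 3

and with inputSize = 6, poolingSize = 3, stride = 1 and SAME padding,

    paddingSize = ((1 - 1) * 6 - 1 + 3) / 2 = 1
    outputSize  = 6

which is consistent with the general relation, since (6 + 2 * 1 - 3) / 1 + 1 = 6. For even pooling windows under SAME the integer division truncates the half-padding, so windows at the bottom-right edge are clipped by the kernel's bounds check rather than fully padded.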