From f60d62f6bdb083225226d2140b99c4297bfd7c83 Mon Sep 17 00:00:00 2001
From: LordMathis
Date: Sun, 28 Apr 2024 19:58:00 +0200
Subject: [PATCH] Implement batch norm layer

---
 include/kernels/matmul.cuh    |  38 ++++++++-
 include/layers/batch_norm.cuh | 123 +++++++++++++++++++++++++++
 src/kernels/matmul.cu         |  26 ++++++
 src/layers/batch_norm.cu      | 156 ++++++++++++++++++++++++++++++++++
 4 files changed, 340 insertions(+), 3 deletions(-)
 create mode 100644 include/layers/batch_norm.cuh
 create mode 100644 src/layers/batch_norm.cu

diff --git a/include/kernels/matmul.cuh b/include/kernels/matmul.cuh
index da7cd5f..f4b8e9a 100644
--- a/include/kernels/matmul.cuh
+++ b/include/kernels/matmul.cuh
@@ -38,7 +38,7 @@ __global__ void vec_vec_add(
 );
 
 /**
- * @brief Add scalar to each element of the vector
+ * @brief Sub scalar from each element of the vector
  *
  * @param d_vector
  * @param d_scalar
@@ -54,7 +54,23 @@ __global__ void vec_scalar_sub(
 );
 
 /**
- * @brief Softmax activation function kernel
+ * @brief Add scalar to each element of the vector
+ *
+ * @param d_src
+ * @param d_out
+ * @param d_scalar
+ * @param len
+ * @return __global__
+ */
+__global__ void vec_scalar_add(
+    const float* __restrict__ d_src,
+    float* __restrict__ d_out,
+    const float* __restrict__ d_scalar,
+    const unsigned int len
+);
+
+/**
+ * @brief Divide each element of the vector by a scalar
  *
  * @param src Pointer to the source array
  * @param dst Pointer to the destination array
@@ -68,7 +84,23 @@ __global__ void vec_scalar_div(
 );
 
 /**
- * @brief Softmax activation exponentiation kernel
+ * @brief Multiply each element of the vector by a scalar
+ *
+ * @param d_src
+ * @param d_out
+ * @param d_scalar
+ * @param len
+ * @return __global__
+ */
+__global__ void vec_scalar_mul(
+    const float* __restrict__ d_src,
+    float* __restrict__ d_out,
+    const float* __restrict__ d_scalar,
+    const unsigned int len
+);
+
+/**
+ * @brief Exponentiate each element of the vector
  *
  * @param src Pointer to the source array
  * @param dst Pointer to the destination array
diff --git a/include/layers/batch_norm.cuh b/include/layers/batch_norm.cuh
new file mode 100644
index 0000000..142c467
--- /dev/null
+++ b/include/layers/batch_norm.cuh
@@ -0,0 +1,123 @@
+#ifndef CUDANET_BATCH_NORM_H
+#define CUDANET_BATCH_NORM_H
+
+#include <vector>
+
+#include "activation.cuh"
+#include "layer.cuh"
+
+namespace CUDANet::Layers {
+
+class BatchNorm : public WeightedLayer {
+  public:
+    BatchNorm(int inputSize, int inputChannels, ActivationType activationType);
+
+    ~BatchNorm();
+
+    /**
+     * @brief Compute the forward pass of the batchnorm layer
+     *
+     * @param d_input Device pointer to the input
+     * @return float* Device pointer to the output
+     */
+    float* forward(const float* d_input);
+
+    /**
+     * @brief Set the weights of the batchnorm layer
+     *
+     * @param weights_input Pointer to the weights
+     */
+    void setWeights(const float* weights_input);
+
+    /**
+     * @brief Get the weights of the batchnorm layer
+     *
+     * @return std::vector<float>
+     */
+    std::vector<float> getWeights();
+
+    /**
+     * @brief Set the biases of the batchnorm layer
+     *
+     * @param biases_input Pointer to the biases
+     */
+    void setBiases(const float* biases_input);
+
+    /**
+     * @brief Get the biases of the batchnorm layer
+     *
+     * @return std::vector<float>
+     */
+    std::vector<float> getBiases();
+
+    /**
+     * @brief Get output size
+     *
+     * @return int output size
+     */
+    int getOutputSize();
+
+    /**
+     * @brief Get input size
+     *
+     * @return int input size
+     */
+    int getInputSize();
+
+  private:
+
+    int inputSize;
+    int inputChannels;
+
+    int gridSize;
+
+    float* d_output;
+
+    float* d_mean;
+    float* d_sqrt_var;
+
+    float* d_weights;
+    float* d_biases;
+
+    std::vector<float> weights;
+    std::vector<float> biases;
+
+    std::vector<float> mean;
+    std::vector<float> sqrt_var;
+
+    Activation* activation;
+
+    /**
+     * @brief Initialize weights of the batchnorm layer with ones
+     *
+     */
+    void initializeWeights();
+
+    /**
+     * @brief Initialize biases of the batchnorm layer with zeros
+     *
+     */
+    void initializeBiases();
+
+    /**
+     * @brief Initialize mean of the batchnorm layer with zeros
+     *
+     */
+    void initializeMean();
+
+    /**
+     * @brief Initialize sqrt of variance of the batchnorm layer with ones
+     *
+     */
+    void initializeSqrtVar();
+
+    /**
+     * @brief Copy weights and biases to the device
+     *
+     */
+    void toCuda();
+};
+
+}  // namespace CUDANet::Layers
+
+#endif  // CUDANET_BATCH_NORM_H
\ No newline at end of file
diff --git a/src/kernels/matmul.cu b/src/kernels/matmul.cu
index ae27cc5..b8c484b 100644
--- a/src/kernels/matmul.cu
+++ b/src/kernels/matmul.cu
@@ -49,6 +49,19 @@ __global__ void Kernels::vec_scalar_sub(
     d_out[tid] = d_src[tid] - *d_scalar;
 }
 
+__global__ void Kernels::vec_scalar_add(
+    const float* __restrict__ d_src,
+    float* __restrict__ d_out,
+    const float* __restrict__ d_scalar,
+    const unsigned int len
+) {
+    int tid = blockDim.x * blockIdx.x + threadIdx.x;
+    if (tid >= len) {
+        return;
+    }
+    d_out[tid] = d_src[tid] + *d_scalar;
+}
+
 __global__ void Kernels::vec_scalar_div(
     const float* __restrict__ d_src,
     float* __restrict__ d_out,
@@ -62,6 +75,19 @@ __global__ void Kernels::vec_scalar_div(
     d_out[tid] = d_src[tid] / *d_scalar;
 }
 
+__global__ void Kernels::vec_scalar_mul(
+    const float* __restrict__ d_src,
+    float* __restrict__ d_out,
+    const float* __restrict__ d_scalar,
+    const unsigned int len
+) {
+    int tid = blockDim.x * blockIdx.x + threadIdx.x;
+    if (tid >= len) {
+        return;
+    }
+    d_out[tid] = d_src[tid] * *d_scalar;
+}
+
 __global__ void Kernels::vec_exp(
     const float* __restrict__ src,
     float* __restrict__ dst,
diff --git a/src/layers/batch_norm.cu b/src/layers/batch_norm.cu
new file mode 100644
index 0000000..c8cee1e
--- /dev/null
+++ b/src/layers/batch_norm.cu
@@ -0,0 +1,156 @@
+#include <algorithm>
+
+#include "activation.cuh"
+#include "batch_norm.cuh"
+#include "cuda_helper.cuh"
+#include "layer.cuh"
+#include "matmul.cuh"
+
+using namespace CUDANet::Layers;
+
+BatchNorm::BatchNorm(
+    int inputSize,
+    int inputChannels,
+    ActivationType activationType
+)
+    : inputSize(inputSize), inputChannels(inputChannels) {
+    activation =
+        new Activation(activationType, inputSize * inputSize * inputChannels);
+
+    d_output = nullptr;
+    CUDA_CHECK(cudaMalloc(
+        (void **)&d_output,
+        sizeof(float) * inputSize * inputSize * inputChannels
+    ));
+
+    d_mean = nullptr;
+    CUDA_CHECK(cudaMalloc((void **)&d_mean, sizeof(float) * inputChannels));
+
+    d_sqrt_var = nullptr;
+    CUDA_CHECK(cudaMalloc((void **)&d_sqrt_var, sizeof(float) * inputChannels));
+
+    d_weights = nullptr;
+    CUDA_CHECK(cudaMalloc((void **)&d_weights, sizeof(float) * inputChannels));
+
+    d_biases = nullptr;
+    CUDA_CHECK(cudaMalloc((void **)&d_biases, sizeof(float) * inputChannels));
+
+    weights.resize(inputChannels);
+    biases.resize(inputChannels);
+    mean.resize(inputChannels);
+    sqrt_var.resize(inputChannels);
+
+    initializeWeights();
+    initializeBiases();
+    initializeMean();
+    initializeSqrtVar();
+
+    toCuda();
+
+    gridSize =
+        (inputSize * inputSize * inputChannels + BLOCK_SIZE - 1) / BLOCK_SIZE;
+}
+
+BatchNorm::~BatchNorm() {
+    cudaFree(d_output);
+    cudaFree(d_mean);
+    cudaFree(d_sqrt_var);
+    cudaFree(d_weights);
+    cudaFree(d_biases);
+}
+
+void BatchNorm::initializeWeights() {
+    std::fill(weights.begin(), weights.end(), 1.0f);
+}
+
+void BatchNorm::initializeBiases() {
+    std::fill(biases.begin(), biases.end(), 0.0f);
+}
+
+void BatchNorm::initializeMean() {
+    std::fill(mean.begin(), mean.end(), 0.0f);
+}
+
+void BatchNorm::initializeSqrtVar() {
+    std::fill(sqrt_var.begin(), sqrt_var.end(), 1.0f);
+}
+
+void BatchNorm::setWeights(const float *weights_input) {
+    std::copy(weights_input, weights_input + weights.size(), weights.begin());
+    toCuda();
+}
+
+std::vector<float> BatchNorm::getWeights() {
+    return weights;
+}
+
+void BatchNorm::setBiases(const float *biases_input) {
+    std::copy(biases_input, biases_input + biases.size(), biases.begin());
+    toCuda();
+}
+
+std::vector<float> BatchNorm::getBiases() {
+    return biases;
+}
+
+void BatchNorm::toCuda() {
+    CUDA_CHECK(cudaMemcpy(
+        d_weights, weights.data(), sizeof(float) * inputChannels,
+        cudaMemcpyHostToDevice
+    ));
+    CUDA_CHECK(cudaMemcpy(
+        d_biases, biases.data(), sizeof(float) * inputChannels,
+        cudaMemcpyHostToDevice
+    ));
+    CUDA_CHECK(cudaMemcpy(
+        d_mean, mean.data(), sizeof(float) * inputChannels,
+        cudaMemcpyHostToDevice
+    ));
+    CUDA_CHECK(cudaMemcpy(
+        d_sqrt_var, sqrt_var.data(), sizeof(float) * inputChannels,
+        cudaMemcpyHostToDevice
+    ));
+}
+
+int BatchNorm::getInputSize() {
+    return inputSize * inputSize * inputChannels;
+}
+
+int BatchNorm::getOutputSize() {
+    return inputSize * inputSize * inputChannels;
+}
+
+float *BatchNorm::forward(const float *d_input) {
+
+    for (int i = 0; i < inputChannels; i++) {
+        Kernels::vec_scalar_sub<<<gridSize, BLOCK_SIZE>>>(
+            d_input + i * inputSize * inputSize,
+            d_output + i * inputSize * inputSize,
+            &d_mean[i],
+            inputSize * inputSize
+        );
+
+        Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
+            d_output + i * inputSize * inputSize,
+            d_output + i * inputSize * inputSize,
+            &d_sqrt_var[i],
+            inputSize * inputSize
+        );
+
+        Kernels::vec_scalar_mul<<<gridSize, BLOCK_SIZE>>>(
+            d_output + i * inputSize * inputSize,
+            d_output + i * inputSize * inputSize,
+            &d_weights[i],
+            inputSize * inputSize
+        );
+
+        Kernels::vec_scalar_add<<<gridSize, BLOCK_SIZE>>>(
+            d_output + i * inputSize * inputSize,
+            d_output + i * inputSize * inputSize,
+            &d_biases[i],
+            inputSize * inputSize
+        );
+    }
+
+    return d_output;
+}
\ No newline at end of file
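
Reviewer note, not part of the patch: a minimal usage sketch of the new layer, for orientation. It assumes an activation value such as ActivationType::NONE is defined in activation.cuh (the enum's values are not shown in this diff); everything else uses only the API introduced above. With the default mean (0) and sqrt_var (1), forward() reduces to y = weight * x + bias per channel; the four kernel launches implement y = weight * (x - mean) / sqrt_var + bias in general.

#include <vector>

#include "batch_norm.cuh"

int main() {
    const int inputSize     = 4;  // 4x4 spatial extent
    const int inputChannels = 2;
    const int n = inputSize * inputSize * inputChannels;

    // Construct the layer; the activation value here is an assumption.
    CUDANet::Layers::BatchNorm bn(inputSize, inputChannels, ActivationType::NONE);

    // Per-channel scale (gamma) and shift (beta).
    std::vector<float> gamma = {1.5f, 0.5f};
    std::vector<float> beta  = {0.1f, -0.1f};
    bn.setWeights(gamma.data());
    bn.setBiases(beta.data());

    // Stage an input tensor on the device and run the forward pass.
    std::vector<float> input(n, 0.5f);
    float *d_input = nullptr;
    cudaMalloc((void **)&d_input, sizeof(float) * n);
    cudaMemcpy(d_input, input.data(), sizeof(float) * n, cudaMemcpyHostToDevice);

    // Returns a device pointer owned by the layer (d_output).
    float *d_out = bn.forward(d_input);

    // cudaMemcpy on the default stream synchronizes with the kernel launches.
    std::vector<float> output(n);
    cudaMemcpy(output.data(), d_out, sizeof(float) * n, cudaMemcpyDeviceToHost);

    cudaFree(d_input);
    return 0;
}

Two observations on the patch itself: gridSize is computed for the full tensor, but each launch in forward() normalizes a single channel of inputSize * inputSize elements, so most blocks exit immediately at the len guard; a per-channel grid of (inputSize * inputSize + BLOCK_SIZE - 1) / BLOCK_SIZE would launch fewer idle blocks. Also, the activation member is constructed but never applied in forward() and never freed in the destructor; both may be intentional for a first cut, but they are worth a look before merging.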