diff --git a/include/cudanet.cuh b/include/cudanet.cuh
index 3caa89d..24cec89 100644
--- a/include/cudanet.cuh
+++ b/include/cudanet.cuh
@@ -12,7 +12,7 @@
 #include "activation.hpp"
 #include "add.hpp"
 #include "avg_pooling.hpp"
-#include "batch_norm.cuh"
+#include "batch_norm.hpp"
 #include "concat.hpp"
 #include "conv2d.hpp"
 #include "dense.hpp"
diff --git a/include/layers/batch_norm.cuh b/include/layers/batch_norm.hpp
similarity index 85%
rename from include/layers/batch_norm.cuh
rename to include/layers/batch_norm.hpp
index 45402ec..a2ae630 100644
--- a/include/layers/batch_norm.cuh
+++ b/include/layers/batch_norm.hpp
@@ -10,7 +10,12 @@ namespace CUDANet::Layers {
 class BatchNorm2d : public WeightedLayer, public TwoDLayer {
   public:
-    BatchNorm2d(shape2d inputSize, int inputChannels, float epsilon, ActivationType activationType);
+    BatchNorm2d(
+        shape2d inputSize,
+        int inputChannels,
+        float epsilon,
+        ActivationType activationType
+    );
 
     ~BatchNorm2d();
 
@@ -52,27 +57,27 @@ class BatchNorm2d : public WeightedLayer, public TwoDLayer {
     /**
      * @brief Set the Running Mean
-     * 
-     * @param running_mean_input 
+     *
+     * @param running_mean_input
      */
     void setRunningMean(const float* running_mean_input);
 
     /**
      * @brief Get the Running Mean
-     * 
+     *
      */
     std::vector<float> getRunningMean();
 
     /**
      * @brief Set the Running Var
-     * 
-     * @param running_mean_input 
+     *
+     * @param running_var_input
      */
     void setRunningVar(const float* running_var_input);
 
     /**
      * @brief Get the Running Var
-     * 
+     *
      */
     std::vector<float> getRunningVar();
 
@@ -93,12 +98,14 @@ class BatchNorm2d : public WeightedLayer, public TwoDLayer {
     shape2d getOutputDims();
 
   private:
-    shape2d inputSize;
-    int inputChannels;
+    int inputChannels;
+    float epsilon;
 
     int gridSize;
 
+#ifdef USE_CUDA
+
     float* d_output;
 
     float* d_running_mean;
     float* d_running_var;
@@ -110,6 +117,19 @@ class BatchNorm2d : public WeightedLayer, public TwoDLayer {
     float* d_weights;
     float* d_biases;
 
+    void initCUDA();
+    void delCUDA();
+
+    /**
+     * @brief Copy weights and biases to the device
+     *
+     */
+    void toCuda();
+
+    float* forwardCUDA(const float* d_input);
+
+#endif
+
     std::vector<float> weights;
     std::vector<float> biases;
 
@@ -118,6 +138,8 @@ class BatchNorm2d : public WeightedLayer, public TwoDLayer {
     Activation* activation;
 
+    float* forwardCPU(const float* input);
+
     /**
      * @brief Initialize weights of the batchnorm layer with ones
      *
@@ -141,12 +163,6 @@ class BatchNorm2d : public WeightedLayer, public TwoDLayer {
      *
      */
     void initializeRunningVar();
-
-    /**
-     * @brief Copy weights and biases to the device
-     *
-     */
-    void toCuda();
 };
 
 }  // namespace CUDANet::Layers
diff --git a/src/layers/batch_norm.cu b/src/backends/cuda/layers/batch_norm.cu
similarity index 53%
rename from src/layers/batch_norm.cu
rename to src/backends/cuda/layers/batch_norm.cu
index 02c33cb..d7bd6c3 100644
--- a/src/layers/batch_norm.cu
+++ b/src/backends/cuda/layers/batch_norm.cu
@@ -1,7 +1,7 @@
 #include <vector>
 
 #include "activation.hpp"
-#include "batch_norm.cuh"
+#include "batch_norm.hpp"
 #include "cuda_helper.cuh"
 #include "layer.hpp"
 #include "matmul.cuh"
@@ -9,17 +9,7 @@
 
 using namespace CUDANet::Layers;
 
-BatchNorm2d::BatchNorm2d(
-    shape2d inputSize,
-    int inputChannels,
-    float epsilon,
-    ActivationType activationType
-)
-    : inputSize(inputSize), inputChannels(inputChannels) {
-    activation = new Activation(
-        activationType, inputSize.first * inputSize.second * inputChannels
-    );
-
+void BatchNorm2d::initCUDA() {
     d_output = nullptr;
     CUDA_CHECK(cudaMalloc(
         (void **)&d_output,
@@ -27,14 +17,14 @@ BatchNorm2d::BatchNorm2d(
     ));
 
     d_running_mean = nullptr;
-    CUDA_CHECK(cudaMalloc(
-        (void **)&d_running_mean, sizeof(float) * inputChannels
-    ));
+    CUDA_CHECK(
+        cudaMalloc((void **)&d_running_mean, sizeof(float) * inputChannels)
+    );
 
     d_running_var = nullptr;
-    CUDA_CHECK(cudaMalloc(
-        (void **)&d_running_var, sizeof(float) * inputChannels
-    ));
+    CUDA_CHECK(
+        cudaMalloc((void **)&d_running_var, sizeof(float) * inputChannels)
+    );
 
     d_weights = nullptr;
     CUDA_CHECK(cudaMalloc((void **)&d_weights, sizeof(float) * inputChannels));
@@ -55,24 +45,11 @@
         cudaMemcpy(d_epsilon, &epsilon, sizeof(float), cudaMemcpyHostToDevice)
     );
 
-    weights.resize(inputChannels);
-    biases.resize(inputChannels);
-
-    running_mean.resize(inputChannels);
-    running_var.resize(inputChannels);
-
-    initializeWeights();
-    initializeBiases();
-    initializeRunningMean();
-    initializeRunningVar();
-
-    toCuda();
-
     gridSize =
         (inputSize.first * inputSize.second + BLOCK_SIZE - 1) / BLOCK_SIZE;
 }
 
-BatchNorm2d::~BatchNorm2d() {
+void BatchNorm2d::delCUDA() {
     cudaFree(d_output);
     cudaFree(d_running_mean);
     cudaFree(d_running_var);
@@ -82,58 +59,6 @@ BatchNorm2d::~BatchNorm2d() {
     cudaFree(d_epsilon);
 }
 
-void BatchNorm2d::initializeWeights() {
-    std::fill(weights.begin(), weights.end(), 1.0f);
-}
-
-void BatchNorm2d::initializeBiases() {
-    std::fill(biases.begin(), biases.end(), 0.0f);
-}
-
-void BatchNorm2d::initializeRunningMean() {
-    std::fill(running_mean.begin(), running_mean.end(), 0.0f);
-}
-
-void BatchNorm2d::initializeRunningVar() {
-    std::fill(running_var.begin(), running_var.end(), 1.0f);
-}
-
-void BatchNorm2d::setWeights(const float *weights_input) {
-    std::copy(weights_input, weights_input + weights.size(), weights.begin());
-    toCuda();
-}
-
-std::vector<float> BatchNorm2d::getWeights() {
-    return weights;
-}
-
-void BatchNorm2d::setBiases(const float *biases_input) {
-    std::copy(biases_input, biases_input + biases.size(), biases.begin());
-    toCuda();
-}
-
-std::vector<float> BatchNorm2d::getBiases() {
-    return biases;
-}
-
-void BatchNorm2d::setRunningMean(const float* running_mean_input) {
-    std::copy(running_mean_input, running_mean_input + inputChannels, running_mean.begin());
-    toCuda();
-}
-
-std::vector<float> BatchNorm2d::getRunningMean() {
-    return running_mean;
-}
-
-void BatchNorm2d::setRunningVar(const float* running_var_input) {
-    std::copy(running_var_input, running_var_input + inputChannels, running_var.begin());
-    toCuda();
-}
-
-std::vector<float> BatchNorm2d::getRunningVar() {
-    return running_var;
-}
-
 void BatchNorm2d::toCuda() {
     CUDA_CHECK(cudaMemcpy(
         d_weights, weights.data(), sizeof(float) * inputChannels,
@@ -153,22 +78,9 @@ void BatchNorm2d::toCuda() {
     ));
 }
 
-int BatchNorm2d::getInputSize() {
-    return inputSize.first * inputSize.second * inputChannels;
-}
-
-int BatchNorm2d::getOutputSize() {
-    return inputSize.first * inputSize.second * inputChannels;
-}
-
-shape2d BatchNorm2d::getOutputDims() {
-    return inputSize;
-}
-
-float *BatchNorm2d::forward(const float *d_input) {
+float *BatchNorm2d::forwardCUDA(const float *d_input) {
     // Compute per-channel batch normalization
     for (int i = 0; i < inputChannels; i++) {
-        // Subtract mean from input
         Kernels::vec_scalar_sub<<<gridSize, BLOCK_SIZE>>>(
             d_input + i * inputSize.first * inputSize.second,
@@ -181,17 +93,14 @@ float *BatchNorm2d::forward(const float *d_input) {
 
         Kernels::vec_scale<<<gridSize, BLOCK_SIZE>>>(
             d_output + i * inputSize.first * inputSize.second,
             d_output + i * inputSize.first * inputSize.second,
-            &d_running_var[i],
-            d_epsilon,
-            inputSize.first * inputSize.second
+            &d_running_var[i], d_epsilon, inputSize.first * inputSize.second
         );
         CUDA_CHECK(cudaGetLastError());
 
         // Multiply by weights
         Kernels::vec_scalar_mul<<<gridSize, BLOCK_SIZE>>>(
             d_output + i * inputSize.first * inputSize.second,
-            d_output + i * inputSize.first * inputSize.second,
-            &d_weights[i],
+            d_output + i * inputSize.first * inputSize.second, &d_weights[i],
             inputSize.first * inputSize.second
         );
         CUDA_CHECK(cudaGetLastError());
@@ -199,8 +108,7 @@ float *BatchNorm2d::forward(const float *d_input) {
         // Add biases
         Kernels::vec_scalar_add<<<gridSize, BLOCK_SIZE>>>(
             d_output + i * inputSize.first * inputSize.second,
-            d_output + i * inputSize.first * inputSize.second,
-            &d_biases[i],
+            d_output + i * inputSize.first * inputSize.second, &d_biases[i],
             inputSize.first * inputSize.second
         );
         CUDA_CHECK(cudaGetLastError());
diff --git a/src/layers/batch_norm.cpp b/src/layers/batch_norm.cpp
new file mode 100644
index 0000000..49e1fa7
--- /dev/null
+++ b/src/layers/batch_norm.cpp
@@ -0,0 +1,133 @@
+#include "batch_norm.hpp"
+
+#include <algorithm>
+#include <stdexcept>
+
+#include "activation.hpp"
+#include "layer.hpp"
+
+using namespace CUDANet::Layers;
+
+BatchNorm2d::BatchNorm2d(
+    shape2d inputSize,
+    int inputChannels,
+    float epsilon,
+    ActivationType activationType
+)
+    : inputSize(inputSize), inputChannels(inputChannels), epsilon(epsilon) {
+    activation = new Activation(
+        activationType, inputSize.first * inputSize.second * inputChannels
+    );
+
+    weights.resize(inputChannels);
+    biases.resize(inputChannels);
+
+    running_mean.resize(inputChannels);
+    running_var.resize(inputChannels);
+
+    initializeWeights();
+    initializeBiases();
+    initializeRunningMean();
+    initializeRunningVar();
+
+#ifdef USE_CUDA
+    initCUDA();
+    toCuda();
+#endif
+}
+
+BatchNorm2d::~BatchNorm2d() {
+#ifdef USE_CUDA
+    delCUDA();
+#endif
+}
+
+void BatchNorm2d::initializeWeights() {
+    std::fill(weights.begin(), weights.end(), 1.0f);
+}
+
+void BatchNorm2d::initializeBiases() {
+    std::fill(biases.begin(), biases.end(), 0.0f);
+}
+
+void BatchNorm2d::initializeRunningMean() {
+    std::fill(running_mean.begin(), running_mean.end(), 0.0f);
+}
+
+void BatchNorm2d::initializeRunningVar() {
+    std::fill(running_var.begin(), running_var.end(), 1.0f);
+}
+
+void BatchNorm2d::setWeights(const float* weights_input) {
+    std::copy(weights_input, weights_input + weights.size(), weights.begin());
+#ifdef USE_CUDA
+    toCuda();
+#endif
+}
+
+std::vector<float> BatchNorm2d::getWeights() {
+    return weights;
+}
+
+void BatchNorm2d::setBiases(const float* biases_input) {
+    std::copy(biases_input, biases_input + biases.size(), biases.begin());
+#ifdef USE_CUDA
+    toCuda();
+#endif
+}
+
+std::vector<float> BatchNorm2d::getBiases() {
+    return biases;
+}
+
+void BatchNorm2d::setRunningMean(const float* running_mean_input) {
+    std::copy(
+        running_mean_input, running_mean_input + inputChannels,
+        running_mean.begin()
+    );
+#ifdef USE_CUDA
+    toCuda();
+#endif
+}
+
+std::vector<float> BatchNorm2d::getRunningMean() {
+    return running_mean;
+}
+
+void BatchNorm2d::setRunningVar(const float* running_var_input) {
+    std::copy(
+        running_var_input, running_var_input + inputChannels,
+        running_var.begin()
+    );
+#ifdef USE_CUDA
+    toCuda();
+#endif
+}
+
+std::vector<float> BatchNorm2d::getRunningVar() {
+    return running_var;
+}
+
+int BatchNorm2d::getInputSize() {
+    return inputSize.first * inputSize.second * inputChannels;
+}
+
+int BatchNorm2d::getOutputSize() {
+    return inputSize.first * inputSize.second * inputChannels;
+}
+
+shape2d BatchNorm2d::getOutputDims() {
+    return inputSize;
+}
+
+float* BatchNorm2d::forwardCPU(const float* input) {
+    throw std::logic_error("Not implemented");
+}
+
+float* BatchNorm2d::forward(const float* input) {
+#ifdef USE_CUDA
+    return forwardCUDA(input);
+#else
+    return forwardCPU(input);
+#endif
+}
\ No newline at end of file
diff --git a/src/model/model.cpp b/src/model/model.cpp
index c5ef706..025eb47 100644
--- a/src/model/model.cpp
+++ b/src/model/model.cpp
@@ -9,7 +9,7 @@
 #include "input.hpp"
 #include "layer.hpp"
 
-#include "batch_norm.cuh"
+#include "batch_norm.hpp"
 
 using namespace CUDANet;
 
diff --git a/test/cuda/layers/test_batch_norm.cu b/test/cuda/layers/test_batch_norm.cu
index e6ea673..b12874a 100644
--- a/test/cuda/layers/test_batch_norm.cu
+++ b/test/cuda/layers/test_batch_norm.cu
@@ -4,7 +4,7 @@
 #include <vector>
 
 #include "activation.hpp"
-#include "batch_norm.cuh"
+#include "batch_norm.hpp"
 
 class BatchNormLayerTest : public ::testing::Test {
   protected:
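
Note on the CPU path: the patch stubs BatchNorm2d::forwardCPU with a std::logic_error, while forwardCUDA spells out the per-channel inference steps (subtract running_mean, scale by 1/sqrt(running_var + epsilon), multiply by the weight, add the bias). The sketch below shows that same computation on the host. It is illustrative only: batch_norm_cpu is a hypothetical free function, not part of the CUDANet API; it assumes CHW layout and omits the activation the layer applies afterwards through its Activation member.

#include <cmath>
#include <cstddef>
#include <vector>

// Host-side inference batch norm over one CHW tensor (hypothetical helper):
// y = (x - mean[c]) / sqrt(var[c] + eps) * gamma[c] + beta[c]
std::vector<float> batch_norm_cpu(
    const std::vector<float>& input,         // channels * height * width
    const std::vector<float>& running_mean,  // one entry per channel
    const std::vector<float>& running_var,   // one entry per channel
    const std::vector<float>& weights,       // gamma, one entry per channel
    const std::vector<float>& biases,        // beta, one entry per channel
    int channels, int height, int width, float epsilon
) {
    std::vector<float> output(input.size());
    const std::size_t plane = static_cast<std::size_t>(height) * width;
    for (int c = 0; c < channels; ++c) {
        // Mirrors the kernel sequence: subtract mean, scale, multiply, add.
        const float inv_std = 1.0f / std::sqrt(running_var[c] + epsilon);
        for (std::size_t i = 0; i < plane; ++i) {
            const std::size_t idx = c * plane + i;
            output[idx] =
                (input[idx] - running_mean[c]) * inv_std * weights[c] +
                biases[c];
        }
    }
    return output;
}

Wiring a loop like this into forwardCPU (writing into a member output buffer rather than returning a vector, to match the float* interface) would make the #else branch functional.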