diff --git a/src/layers/batch_norm.cu b/src/layers/batch_norm.cu
index c8cee1e..1df5967 100644
--- a/src/layers/batch_norm.cu
+++ b/src/layers/batch_norm.cu
@@ -129,6 +129,7 @@ float *BatchNorm::forward(const float *d_input) {
             &d_mean[i],
             inputSize * inputSize
         );
+        CUDA_CHECK(cudaGetLastError());
 
         Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
             d_output + i * inputSize * inputSize,
@@ -136,6 +137,7 @@ float *BatchNorm::forward(const float *d_input) {
             &d_sqrt_var[i],
             inputSize * inputSize
         );
+        CUDA_CHECK(cudaGetLastError());
 
         Kernels::vec_scalar_mul<<<gridSize, BLOCK_SIZE>>>(
             d_output + i * inputSize * inputSize,
@@ -143,6 +145,7 @@ float *BatchNorm::forward(const float *d_input) {
             &d_weights[i],
             inputSize * inputSize
         );
+        CUDA_CHECK(cudaGetLastError());
 
         Kernels::vec_scalar_add<<<gridSize, BLOCK_SIZE>>>(
             d_output + i * inputSize * inputSize,
@@ -150,6 +153,7 @@ float *BatchNorm::forward(const float *d_input) {
             &d_biases[i],
             inputSize * inputSize
         );
+        CUDA_CHECK(cudaGetLastError());
     }
 
     return d_output;
diff --git a/test/layers/test_batch_norm.cu b/test/layers/test_batch_norm.cu
new file mode 100644
index 0000000..c791d33
--- /dev/null
+++ b/test/layers/test_batch_norm.cu
@@ -0,0 +1,80 @@
+#include <cuda_runtime.h>
+#include <gtest/gtest.h>
+
+#include <vector>
+
+#include "activation.cuh"
+#include "batch_norm.cuh"
+
+TEST(BatchNormLayerTest, BatchNormSmallForwardTest) {
+    int inputSize = 4;
+    int nChannels = 2;
+
+    cudaError_t cudaStatus;
+
+    CUDANet::Layers::BatchNorm batchNorm(
+        inputSize, nChannels, CUDANet::Layers::ActivationType::NONE
+    );
+
+    std::vector<float> weights = {0.63508f, 0.64903f};
+    std::vector<float> biases = {0.25079f, 0.66841f};
+
+    batchNorm.setWeights(weights.data());
+    batchNorm.setBiases(biases.data());
+
+    cudaStatus = cudaGetLastError();
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    // clang-format off
+    std::vector<float> input = {
+        // Channel 0
+        0.38899f, 0.80478f, 0.48836f, 0.97381f,
+        0.57508f, 0.60835f, 0.65467f, 0.00168f,
+        0.65869f, 0.74235f, 0.17928f, 0.70349f,
+        0.15524f, 0.38664f, 0.23411f, 0.7137f,
+        // Channel 1
+        0.32473f, 0.15698f, 0.314f, 0.60888f,
+        0.80268f, 0.99766f, 0.93694f, 0.89237f,
+        0.13449f, 0.27367f, 0.53036f, 0.18962f,
+        0.57672f, 0.48364f, 0.10863f, 0.0571f
+    };
+    // clang-format on
+
+    std::vector<float> output(input.size());
+
+    float* d_input;
+    cudaStatus = cudaMalloc((void**)&d_input, sizeof(float) * input.size());
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    cudaStatus = cudaMemcpy(
+        d_input, input.data(), sizeof(float) * input.size(), cudaMemcpyHostToDevice
+    );
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    float* d_output = batchNorm.forward(d_input);
+
+    cudaStatus = cudaMemcpy(
+        output.data(), d_output, sizeof(float) * output.size(), cudaMemcpyDeviceToHost
+    );
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    std::vector<float> expected = {
+        -0.06007f, 0.951f, 0.18157f, 1.36202f,
+        0.39244f, 0.47335f, 0.58598f, -1.00188f,
+        0.59576f, 0.79919f, -0.57001f, 0.70469f,
+        -0.62847f, -0.06578f, -0.43668f, 0.72952f,
+        0.37726f, 0.02088f, 0.35446f, 0.98092f,
+        1.39264f, 1.80686f, 1.67786f, 1.58318f,
+        -0.0269f, 0.26878f, 0.81411f, 0.09022f,
+        0.9126f, 0.71485f, -0.08184f, -0.19131f
+    };
+
+    // Compare with a tolerance: the expected values are rounded to five
+    // decimals, so exact float equality would fail
+    for (int i = 0; i < output.size(); i++) {
+        EXPECT_NEAR(output[i], expected[i], 1e-5f);
+    }
+
+    cudaFree(d_input);
+
+}
\ No newline at end of file
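Note: `CUDA_CHECK` is called in `forward` but its definition is not part of this diff; it is assumed to live in a shared header elsewhere in the repo. A minimal sketch of a typical definition, for context only (the repo's actual macro may differ):

```cpp
// Hypothetical CUDA_CHECK: aborts with file/line context when a CUDA API call
// (or a cudaGetLastError() probe after a kernel launch) reports an error.
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

#define CUDA_CHECK(call)                                                  \
    do {                                                                  \
        cudaError_t err_ = (call);                                        \
        if (err_ != cudaSuccess) {                                        \
            std::fprintf(stderr, "CUDA error %s at %s:%d\n",              \
                         cudaGetErrorString(err_), __FILE__, __LINE__);   \
            std::exit(EXIT_FAILURE);                                      \
        }                                                                 \
    } while (0)
```

Checking `cudaGetLastError()` right after each launch surfaces bad launch configurations and kernel faults immediately, instead of at the next synchronizing call where the failing kernel is harder to identify.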
diff --git a/tools/batch_norm_test.py b/tools/batch_norm_test.py
new file mode 100644
index 0000000..0003b83
--- /dev/null
+++ b/tools/batch_norm_test.py
@@ -0,0 +1,31 @@
+import torch
+
+from utils import print_cpp_vector
+
+batch_norm = torch.nn.BatchNorm2d(2, track_running_stats=False)
+
+weights = torch.Tensor([0.63508, 0.64903])
+biases = torch.Tensor([0.25079, 0.66841])
+
+batch_norm.weight = torch.nn.Parameter(weights)
+batch_norm.bias = torch.nn.Parameter(biases)
+
+input = torch.Tensor([
+    # Channel 0
+    0.38899, 0.80478, 0.48836, 0.97381,
+    0.57508, 0.60835, 0.65467, 0.00168,
+    0.65869, 0.74235, 0.17928, 0.70349,
+    0.15524, 0.38664, 0.23411, 0.7137,
+    # Channel 1
+    0.32473, 0.15698, 0.314, 0.60888,
+    0.80268, 0.99766, 0.93694, 0.89237,
+    0.13449, 0.27367, 0.53036, 0.18962,
+    0.57672, 0.48364, 0.10863, 0.0571
+]).reshape(1, 2, 4, 4)
+
+output = batch_norm(input)
+print_cpp_vector(output.flatten())
+
+print(batch_norm.running_mean)
+print(batch_norm.running_var)
+
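For reference, the expected vector is generated by PyTorch's `BatchNorm2d` running on batch statistics (`track_running_stats=False`). A host-side sketch of the per-channel math that the sub/div/mul/add kernel sequence in `BatchNorm::forward` implements — assuming the layer folds PyTorch's default `eps = 1e-5` into `d_sqrt_var` and uses the biased (divide-by-N) variance, as PyTorch does for normalization:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Reference for one channel: y = gamma * (x - mean) / sqrt(var + eps) + beta.
// Hypothetical helper for sanity-checking expected test values on the host;
// not part of the repo.
std::vector<float> batchNormChannelRef(
    const std::vector<float>& x, float gamma, float beta, float eps = 1e-5f
) {
    float mean = 0.0f;
    for (float v : x) mean += v;
    mean /= static_cast<float>(x.size());

    float var = 0.0f;  // biased variance, matching PyTorch's normalization path
    for (float v : x) var += (v - mean) * (v - mean);
    var /= static_cast<float>(x.size());

    std::vector<float> y(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) {
        y[i] = gamma * (x[i] - mean) / std::sqrt(var + eps) + beta;
    }
    return y;
}
```

Feeding channel 0's sixteen inputs through this with `gamma = 0.63508` and `beta = 0.25079` should reproduce the first sixteen expected values to within the 1e-5 tolerance used in the test.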