diff --git a/include/kernels/matmul.cuh b/include/kernels/matmul.cuh index 27a4b24..0c0bfef 100644 --- a/include/kernels/matmul.cuh +++ b/include/kernels/matmul.cuh @@ -141,7 +141,6 @@ __global__ void vec_exp( * @param src Device pointer to source vector * @param dst Device pointer to destination vector * @param len Length of the vector - * @return __global__ */ __global__ void vec_sqrt( const float* __restrict__ src, @@ -149,6 +148,23 @@ __global__ void vec_sqrt( const unsigned int len ); +/** + * @brief Multiplies every element of src by 1/sqrt(*scale + *epsilon) and writes the result to dst + * + * @param src Device pointer to source vector + * @param dst Device pointer to destination vector + * @param scale Device pointer to a single float; it is dereferenced in the kernel and added to epsilon under the square root + * @param epsilon Device pointer to a single float added to scale before the reciprocal square root + * @param len Length of the vector + */ +__global__ void vec_scale( + const float* __restrict__ src, + float* __restrict__ dst, + const float* __restrict__ scale, + const float* epsilon, + const unsigned int len +); + /** * @brief Max reduction kernel * diff --git a/include/layers/batch_norm.cuh b/include/layers/batch_norm.cuh index dad5111..1796c54 100644 --- a/include/layers/batch_norm.cuh +++ b/include/layers/batch_norm.cuh @@ -50,6 +50,20 @@ class BatchNorm2d : public WeightedLayer, public TwoDLayer { */ std::vector getBiases(); + /** + * @brief Set the running mean and upload it to the device + * + * @param running_mean_input Host pointer to inputChannels floats that are copied into running_mean + */ + void setRunningMean(const float* running_mean_input); + + /** + * @brief Set the running variance and upload it to the device + * + * @param running_mean_input Host pointer to inputChannels running variance values (the definition names this parameter running_var_input) + */ + void setRunningVar(const float* running_mean_input); + /** * @brief Get output size * @@ -75,9 +89,8 @@ class BatchNorm2d : public WeightedLayer, public TwoDLayer { float* d_output; - float* d_mean; - float* d_mean_sub; - float* d_sqrt_var; + float* d_running_mean; + float* d_running_var; float* d_length; float* d_epsilon; @@ -88,8 +101,8 @@ class BatchNorm2d : public WeightedLayer, public TwoDLayer { std::vector weights; std::vector biases; - std::vector mean; - std::vector sqrt_var; + std::vector running_mean; + std::vector 
running_var; Activation* activation; @@ -109,13 +122,13 @@ class BatchNorm2d : public WeightedLayer, public TwoDLayer { * @brief Initialize mean of the batchnorm layer with zeros * */ - void initializeMean(); + void initializeRunningMean(); /** * @brief Initialize sqrt of variance of the batchnorm layer with ones * */ - void initializeSqrtVar(); + void initializeRunningVar(); /** * @brief Copy weights and biases to the device diff --git a/src/kernels/matmul.cu b/src/kernels/matmul.cu index e4ef8da..31c68d5 100644 --- a/src/kernels/matmul.cu +++ b/src/kernels/matmul.cu @@ -140,6 +140,19 @@ __global__ void Kernels::vec_sqrt( } } +__global__ void Kernels::vec_scale( + const float* __restrict__ src, + float* __restrict__ dst, + const float* __restrict__ scale, + const float* epsilon, + const unsigned int len +) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < len) { + float inv_std = rsqrtf(*scale + *epsilon); + dst[idx] = src[idx] * inv_std; + } +} __global__ void Kernels::max_reduce( const float* __restrict__ d_vector, diff --git a/src/layers/batch_norm.cu b/src/layers/batch_norm.cu index cea7477..9b82606 100644 --- a/src/layers/batch_norm.cu +++ b/src/layers/batch_norm.cu @@ -26,19 +26,14 @@ BatchNorm2d::BatchNorm2d( sizeof(float) * inputSize.first * inputSize.second * inputChannels )); - d_mean = nullptr; + d_running_mean = nullptr; CUDA_CHECK(cudaMalloc( - (void **)&d_mean, sizeof(float) * inputSize.first * inputSize.second + (void **)&d_running_mean, sizeof(float) * inputChannels )); - d_mean_sub = nullptr; + d_running_var = nullptr; CUDA_CHECK(cudaMalloc( - (void **)&d_mean_sub, sizeof(float) * inputSize.first * inputSize.second - )); - - d_sqrt_var = nullptr; - CUDA_CHECK(cudaMalloc( - (void **)&d_sqrt_var, sizeof(float) * inputSize.first * inputSize.second + (void **)&d_running_var, sizeof(float) * inputChannels )); d_weights = nullptr; @@ -63,8 +58,13 @@ BatchNorm2d::BatchNorm2d( weights.resize(inputChannels); biases.resize(inputChannels); + 
running_mean.resize(inputChannels); + running_var.resize(inputChannels); + initializeWeights(); initializeBiases(); + initializeRunningMean(); + initializeRunningVar(); toCuda(); @@ -74,9 +74,8 @@ BatchNorm2d::BatchNorm2d( BatchNorm2d::~BatchNorm2d() { cudaFree(d_output); - cudaFree(d_mean); - cudaFree(d_mean_sub); - cudaFree(d_sqrt_var); + cudaFree(d_running_mean); + cudaFree(d_running_var); cudaFree(d_weights); cudaFree(d_biases); cudaFree(d_length); @@ -91,6 +90,14 @@ void BatchNorm2d::initializeBiases() { std::fill(biases.begin(), biases.end(), 0.0f); } +void BatchNorm2d::initializeRunningMean() { + std::fill(running_mean.begin(), running_mean.end(), 0.0f); +} + +void BatchNorm2d::initializeRunningVar() { + std::fill(running_var.begin(), running_var.end(), 1.0f); +} + void BatchNorm2d::setWeights(const float *weights_input) { std::copy(weights_input, weights_input + weights.size(), weights.begin()); toCuda(); @@ -109,6 +116,16 @@ std::vector BatchNorm2d::getBiases() { return biases; } +void BatchNorm2d::setRunningMean(const float* running_mean_input) { + std::copy(running_mean_input, running_mean_input + inputChannels, running_mean.begin()); + toCuda(); +} + +void BatchNorm2d::setRunningVar(const float* running_var_input) { + std::copy(running_var_input, running_var_input + inputChannels, running_var.begin()); + toCuda(); +} + void BatchNorm2d::toCuda() { CUDA_CHECK(cudaMemcpy( d_weights, weights.data(), sizeof(float) * inputChannels, @@ -118,6 +135,14 @@ void BatchNorm2d::toCuda() { d_biases, biases.data(), sizeof(float) * inputChannels, cudaMemcpyHostToDevice )); + CUDA_CHECK(cudaMemcpy( + d_running_mean, running_mean.data(), sizeof(float) * inputChannels, + cudaMemcpyHostToDevice + )); + CUDA_CHECK(cudaMemcpy( + d_running_var, running_var.data(), sizeof(float) * inputChannels, + cudaMemcpyHostToDevice + )); } int BatchNorm2d::getInputSize() { @@ -135,48 +160,30 @@ shape2d BatchNorm2d::getOutputDims() { float *BatchNorm2d::forward(const float *d_input) { // 
Compute per-channel batch normalization for (int i = 0; i < inputChannels; i++) { - // Compute mean - Utils::mean( - d_input + i * inputSize.first * inputSize.second, d_mean, d_length, - inputSize.first * inputSize.second - ); // Subtract mean from input Kernels::vec_scalar_sub<<>>( - d_input + i * inputSize.first * inputSize.second, d_mean_sub, - &d_mean[0], inputSize.first * inputSize.second + d_input + i * inputSize.first * inputSize.second, + d_output + i * inputSize.first * inputSize.second, + &d_running_mean[i], inputSize.first * inputSize.second ); CUDA_CHECK(cudaGetLastError()); - // Compute variance - Utils::var( - d_mean_sub, d_sqrt_var, d_length, inputSize.first * inputSize.second - ); - - // Add epsilon to variance to avoid division by zero - Kernels::vec_scalar_add<<>>( - d_sqrt_var, d_sqrt_var, &d_epsilon[0], + // Divide by sqrt(running_var + epsilon) + Kernels::vec_scale<<>>( + d_output + i * inputSize.first * inputSize.second, + d_output + i * inputSize.first * inputSize.second, + &d_running_var[i], + d_epsilon, inputSize.first * inputSize.second ); CUDA_CHECK(cudaGetLastError()); - // Compute squared root of variance - Kernels::vec_sqrt<<>>( - d_sqrt_var, d_sqrt_var, inputSize.first * inputSize.second - ); - CUDA_CHECK(cudaGetLastError()); - - // Divide by squared root of variance - Kernels::vec_scalar_div<<>>( - d_mean_sub, d_output + i * inputSize.first * inputSize.second, - &d_sqrt_var[0], inputSize.first * inputSize.second - ); - CUDA_CHECK(cudaGetLastError()); - // Multiply by weights Kernels::vec_scalar_mul<<>>( d_output + i * inputSize.first * inputSize.second, - d_output + i * inputSize.first * inputSize.second, &d_weights[i], + d_output + i * inputSize.first * inputSize.second, + &d_weights[i], inputSize.first * inputSize.second ); CUDA_CHECK(cudaGetLastError()); @@ -184,7 +191,8 @@ float *BatchNorm2d::forward(const float *d_input) { // Add biases Kernels::vec_scalar_add<<>>( d_output + i * inputSize.first * inputSize.second, - d_output + 
i * inputSize.first * inputSize.second, &d_biases[i], + d_output + i * inputSize.first * inputSize.second, + &d_biases[i], inputSize.first * inputSize.second ); CUDA_CHECK(cudaGetLastError()); diff --git a/test/kernels/test_matmul.cu b/test/kernels/test_matmul.cu index 3941be7..7d72726 100644 --- a/test/kernels/test_matmul.cu +++ b/test/kernels/test_matmul.cu @@ -66,9 +66,7 @@ TEST(MatMulTest, MatVecMulTest) { cudaFree(d_matrix); cudaFree(d_vector); - cudaFree(d_output); - - + cudaFree(d_output); } TEST(MatMulTest, MaxReduceTest) { @@ -211,4 +209,60 @@ TEST(MatMulTest, SumReduceTest) { cudaFree(d_input); cudaFree(d_sum); +} + +TEST(MatMulTest, VecScaleTest) { + cudaError_t cudaStatus; + int len = 1000; + float* d_src; + float* d_dst; + float* d_scale; + float* d_epsilon; + + cudaStatus = cudaMalloc((void**)&d_src, sizeof(float) * len); + EXPECT_EQ(cudaStatus, cudaSuccess); + + cudaStatus = cudaMalloc((void**)&d_dst, sizeof(float) * len); + EXPECT_EQ(cudaStatus, cudaSuccess); + + cudaStatus = cudaMalloc((void**)&d_scale, sizeof(float)); + EXPECT_EQ(cudaStatus, cudaSuccess); + + cudaStatus = cudaMalloc((void**)&d_epsilon, sizeof(float)); + EXPECT_EQ(cudaStatus, cudaSuccess); + + std::vector src(len); + for (int i = 0; i < len; ++i) { + src[i] = static_cast(rand()) / RAND_MAX; + } + + float scale = 1.5f; + float epsilon = 1e-5f; + + cudaStatus = cudaMemcpy(d_src, src.data(), sizeof(float) * len, cudaMemcpyHostToDevice); + EXPECT_EQ(cudaStatus, cudaSuccess); + cudaStatus = cudaMemcpy(d_scale, &scale, sizeof(float), cudaMemcpyHostToDevice); + EXPECT_EQ(cudaStatus, cudaSuccess); + cudaStatus = cudaMemcpy(d_epsilon, &epsilon, sizeof(float), cudaMemcpyHostToDevice); + EXPECT_EQ(cudaStatus, cudaSuccess); + + int grid_size = (len + BLOCK_SIZE - 1) / BLOCK_SIZE; + CUDANet::Kernels::vec_scale<<>>(d_src, d_dst, d_scale, d_epsilon, len); + + cudaStatus = cudaDeviceSynchronize(); + EXPECT_EQ(cudaStatus, cudaSuccess); + + std::vector dst_gpu(len); + cudaStatus = 
cudaMemcpy(dst_gpu.data(), d_dst, sizeof(float) * len, cudaMemcpyDeviceToHost); + EXPECT_EQ(cudaStatus, cudaSuccess); + + float inv_std = 1.0f / std::sqrt(scale + epsilon); + for (int i = 0; i < len; ++i) { + EXPECT_NEAR(src[i] * inv_std, dst_gpu[i], 1e-5f); + } + + cudaFree(d_src); + cudaFree(d_dst); + cudaFree(d_scale); + cudaFree(d_epsilon); } \ No newline at end of file diff --git a/test/layers/test_batch_norm.cu b/test/layers/test_batch_norm.cu index 10acb0d..904eb44 100644 --- a/test/layers/test_batch_norm.cu +++ b/test/layers/test_batch_norm.cu @@ -8,10 +8,14 @@ class BatchNormLayerTest : public ::testing::Test { protected: - shape2d inputSize; + shape2d inputSize; int nChannels; std::vector weights; std::vector biases; + + std::vector runningMean; + std::vector runningVar; + std::vector input; std::vector expected; @@ -41,6 +45,9 @@ class BatchNormLayerTest : public ::testing::Test { batchNorm->setWeights(weights.data()); batchNorm->setBiases(biases.data()); + batchNorm->setRunningMean(runningMean.data()); + batchNorm->setRunningVar(runningVar.data()); + cudaStatus = cudaGetLastError(); EXPECT_EQ(cudaStatus, cudaSuccess); @@ -78,6 +85,9 @@ TEST_F(BatchNormLayerTest, BatchNormSmallForwardTest) { weights = {0.63508f, 0.64903f}; biases = {0.25079f, 0.66841f}; + runningMean = {0.5f, 0.5f}; + runningVar = {1.0f, 1.0f}; + // clang-format off input = { // Channel 0 @@ -93,12 +103,12 @@ TEST_F(BatchNormLayerTest, BatchNormSmallForwardTest) { }; // clang-format on - expected = {-0.06007f, 0.951f, 0.18157f, 1.36202f, 0.39244f, 0.47335f, - 0.58598f, -1.00188f, 0.59576f, 0.79919f, -0.57001f, 0.70469f, - -0.62847f, -0.06578f, -0.43668f, 0.72952f, 0.37726f, 0.02088f, - 0.35446f, 0.98092f, 1.39264f, 1.80686f, 1.67786f, 1.58318f, - -0.0269f, 0.26878f, 0.81411f, 0.09022f, 0.9126f, 0.71485f, - -0.08184f, -0.19131f}; + expected = {0.18029f, 0.44435f, 0.2434f, 0.5517f, 0.29847f, 0.3196f, + 0.34902f, -0.06568f, 0.35157f, 0.4047f, 0.04711f, 0.38002f, + 0.03184f, 0.1788f, 
0.08193f, 0.38651f, 0.55466f, 0.44578f, + 0.54769f, 0.73908f, 0.86486f, 0.9914f, 0.952f, 0.92307f, + 0.43118f, 0.52152f, 0.68811f, 0.46697f, 0.7182f, 0.65779f, + 0.4144f, 0.38096f}; runTest(); } @@ -109,6 +119,9 @@ TEST_F(BatchNormLayerTest, BatchNormNonSquareInputTest) { weights = {0.63508f, 0.64903f}; biases = {0.25079f, 0.66841f}; + runningMean = {0.5f, 0.5f}; + runningVar = {1.0f, 1.0f}; + input = {// Channel 0 0.38899f, 0.80478f, 0.48836f, 0.97381f, 0.21567f, 0.92312f, 0.57508f, 0.60835f, 0.65467f, 0.00168f, 0.31567f, 0.71345f, @@ -121,16 +134,14 @@ TEST_F(BatchNormLayerTest, BatchNormNonSquareInputTest) { 0.48364f, 0.10863f, 0.0571f, 0.78934f, 0.67545f }; - expected = {-0.05598f, 0.87495f, 0.1665f, 1.2534f, -0.44404f, - 1.13991f, 0.36066f, 0.43515f, 0.53886f, -0.92315f, - -0.22014f, 0.67047f, 0.54786f, 0.73517f, -0.52552f, - 0.64817f, -0.63907f, 1.21453f, -0.57934f, -0.06124f, - -0.40275f, 0.67103f, -0.32712f, 0.94064f, 0.28344f, - -0.08405f, 0.25993f, 0.90592f, 0.07909f, 1.30149f, - 1.33047f, 1.7576f, 1.62459f, 1.52695f, 0.9135f, - 1.59436f, -0.13331f, 0.17158f, 0.73391f, -0.01254f, - 0.57151f, -0.10979f, 0.83546f, 0.63156f, -0.18996f, - -0.30285f, 1.30124f, 1.05175f}; + expected = {0.18029f, 0.44435f, 0.2434f, 0.5517f, 0.07022f, 0.5195f, + 0.29847f, 0.3196f, 0.34902f, -0.06568f, 0.13373f, 0.38635f, + 0.35157f, 0.4047f, 0.04711f, 0.38002f, 0.0149f, 0.54067f, + 0.03184f, 0.1788f, 0.08193f, 0.38651f, 0.10338f, 0.46298f, + 0.55466f, 0.44578f, 0.54769f, 0.73908f, 0.49411f, 0.85627f, + 0.86486f, 0.9914f, 0.952f, 0.92307f, 0.74132f, 0.94304f, + 0.43118f, 0.52152f, 0.68811f, 0.46697f, 0.64f, 0.43815f, + 0.7182f, 0.65779f, 0.4144f, 0.38096f, 0.8562f, 0.78228f}; runTest(); } \ No newline at end of file diff --git a/tools/batch_norm_test.py b/tools/batch_norm_test.py index de8db7a..e00fc2e 100644 --- a/tools/batch_norm_test.py +++ b/tools/batch_norm_test.py @@ -5,7 +5,7 @@ from utils import print_cpp_vector def gen_batch_norm_test_result(input): - batch_norm = 
torch.nn.BatchNorm2d(2, track_running_stats=False) + batch_norm = torch.nn.BatchNorm2d(2, track_running_stats=True) weights = torch.Tensor([0.63508, 0.64903]) biases = torch.Tensor([0.25079, 0.66841]) @@ -13,7 +13,13 @@ def gen_batch_norm_test_result(input): batch_norm.weight = torch.nn.Parameter(weights) batch_norm.bias = torch.nn.Parameter(biases) + batch_norm.running_mean = torch.Tensor([0.5, 0.5]) + batch_norm.running_var = torch.Tensor([1.0, 1.0]) + + batch_norm.eval() + output = batch_norm(input) + print_cpp_vector(output.flatten()) diff --git a/tools/utils.py b/tools/utils.py index e965d0b..bbc5f33 100644 --- a/tools/utils.py +++ b/tools/utils.py @@ -35,6 +35,19 @@ def export_model_weights(model: torch.nn.Module, filename): tensor_data += tensor_bytes + # print(model.named_buffers) + + # Add buffers (only running_mean and running_var; every other buffer is skipped) + for name, buf in model.named_buffers(): + if "running_mean" not in name and "running_var" not in name: + continue + + tensor_bytes = buf.type(torch.float32).detach().numpy().tobytes() + tensor_size = buf.numel() + header += f"{name},{tensor_size},{offset}\n" + offset += len(tensor_bytes) + tensor_data += tensor_bytes + f.seek(0) f.write(struct.pack("H", version)) f.write(struct.pack("Q", len(header)))