Implement batch norm test
Adds a CUDA_CHECK(cudaGetLastError()) after each kernel launch in BatchNorm::forward, and adds a GoogleTest forward-pass test plus a PyTorch script that generates the reference values.
@@ -129,6 +129,7 @@ float *BatchNorm::forward(const float *d_input) {
             &d_mean[i],
             inputSize * inputSize
         );
+        CUDA_CHECK(cudaGetLastError());
 
         Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
             d_output + i * inputSize * inputSize,
@@ -136,6 +137,7 @@ float *BatchNorm::forward(const float *d_input) {
             &d_sqrt_var[i],
             inputSize * inputSize
         );
+        CUDA_CHECK(cudaGetLastError());
 
         Kernels::vec_scalar_mul<<<gridSize, BLOCK_SIZE>>>(
             d_output + i * inputSize * inputSize,
@@ -143,6 +145,7 @@ float *BatchNorm::forward(const float *d_input) {
             &d_weights[i],
             inputSize * inputSize
         );
+        CUDA_CHECK(cudaGetLastError());
 
         Kernels::vec_scalar_add<<<gridSize, BLOCK_SIZE>>>(
             d_output + i * inputSize * inputSize,
@@ -150,6 +153,7 @@ float *BatchNorm::forward(const float *d_input) {
             &d_biases[i],
             inputSize * inputSize
         );
+        CUDA_CHECK(cudaGetLastError());
     }
 
     return d_output;
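Each kernel launch in forward() is now followed by CUDA_CHECK(cudaGetLastError()), which surfaces launch failures (invalid configuration, bad grid dimensions, etc.) at the call site instead of at the next synchronizing API call. The macro's actual definition lives elsewhere in the repository; a minimal sketch of the usual pattern, illustrative rather than CUDANet's real definition:

#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

// Report the failing CUDA call with file/line context and abort.
#define CUDA_CHECK(call)                                                \
    do {                                                                \
        cudaError_t err_ = (call);                                      \
        if (err_ != cudaSuccess) {                                      \
            std::fprintf(stderr, "CUDA error %s at %s:%d\n",            \
                         cudaGetErrorString(err_), __FILE__, __LINE__); \
            std::exit(EXIT_FAILURE);                                    \
        }                                                               \
    } while (0)

Note that cudaGetLastError() catches launch-time errors; failures that occur asynchronously inside a running kernel only become visible after a later synchronizing call such as cudaDeviceSynchronize().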
test/layers/test_batch_norm.cu (new file, 80 lines)
@@ -0,0 +1,80 @@
+#include <cuda_runtime.h>
+#include <gtest/gtest.h>
+
+#include <vector>
+
+#include "activation.cuh"
+#include "batch_norm.cuh"
+
+TEST(BatchNormLayerTest, BatchNormSmallForwardTest) {
+    int inputSize = 4;
+    int nChannels = 2;
+
+    cudaError_t cudaStatus;
+
+    CUDANet::Layers::BatchNorm batchNorm(
+        inputSize, nChannels, CUDANet::Layers::ActivationType::NONE
+    );
+
+    std::vector<float> weights = {0.63508f, 0.64903f};
+    std::vector<float> biases = {0.25079f, 0.66841f};
+
+    batchNorm.setWeights(weights.data());
+    batchNorm.setBiases(biases.data());
+
+    cudaStatus = cudaGetLastError();
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    // clang-format off
+    std::vector<float> input = {
+        // Channel 0
+        0.38899f, 0.80478f, 0.48836f, 0.97381f,
+        0.57508f, 0.60835f, 0.65467f, 0.00168f,
+        0.65869f, 0.74235f, 0.17928f, 0.70349f,
+        0.15524f, 0.38664f, 0.23411f, 0.7137f,
+        // Channel 1
+        0.32473f, 0.15698f, 0.314f, 0.60888f,
+        0.80268f, 0.99766f, 0.93694f, 0.89237f,
+        0.13449f, 0.27367f, 0.53036f, 0.18962f,
+        0.57672f, 0.48364f, 0.10863f, 0.0571f
+    };
+    // clang-format on
+
+    std::vector<float> output(input.size());
+
+    float* d_input;
+    cudaStatus = cudaMalloc((void**)&d_input, sizeof(float) * input.size());
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    cudaStatus = cudaMemcpy(
+        d_input, input.data(), sizeof(float) * input.size(), cudaMemcpyHostToDevice
+    );
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    float* d_output = batchNorm.forward(d_input);
+
+    cudaStatus = cudaMemcpy(
+        output.data(), d_output, sizeof(float) * output.size(), cudaMemcpyDeviceToHost
+    );
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    std::vector<float> expected = {
+        -0.06007f, 0.951f, 0.18157f, 1.36202f,
+        0.39244f, 0.47335f, 0.58598f, -1.00188f,
+        0.59576f, 0.79919f, -0.57001f, 0.70469f,
+        -0.62847f, -0.06578f, -0.43668f, 0.72952f,
+        0.37726f, 0.02088f, 0.35446f, 0.98092f,
+        1.39264f, 1.80686f, 1.67786f, 1.58318f,
+        -0.0269f, 0.26878f, 0.81411f, 0.09022f,
+        0.9126f, 0.71485f, -0.08184f, -0.19131f
+    };
+
+    // std::cout << "BatchNorm: " << std::endl;
+    for (int i = 0; i < output.size(); i++) {
+        EXPECT_EQ(output[i], expected[i]);
+        // std::cout << output[i] << " ";
+    }
+    // std::cout << std::endl;
+    cudaFree(d_input);
+
+}
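A caveat on the assertions: EXPECT_EQ compares floats for exact equality, but the expected values are PyTorch outputs rounded to five decimal places, so the GPU's full-precision results will generally not match bit-for-bit. A tolerance-based variant of the comparison loop, using GoogleTest's standard EXPECT_NEAR (the 1e-5f tolerance is an assumption, not taken from the repository):

    // Assert element-wise agreement within an absolute tolerance.
    for (int i = 0; i < static_cast<int>(output.size()); i++) {
        EXPECT_NEAR(output[i], expected[i], 1e-5f);
    }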
tools/batch_norm_test.py (new file, 31 lines)
@@ -0,0 +1,31 @@
+import torch
+
+from utils import print_cpp_vector
+
+batch_norm = torch.nn.BatchNorm2d(2, track_running_stats=False)
+
+weights = torch.Tensor([0.63508, 0.64903])
+biases = torch.Tensor([0.25079, 0.66841])
+
+batch_norm.weight = torch.nn.Parameter(weights)
+batch_norm.bias = torch.nn.Parameter(biases)
+
+input = torch.Tensor([
+    # Channel 0
+    0.38899, 0.80478, 0.48836, 0.97381,
+    0.57508, 0.60835, 0.65467, 0.00168,
+    0.65869, 0.74235, 0.17928, 0.70349,
+    0.15524, 0.38664, 0.23411, 0.7137,
+    # Channel 1
+    0.32473, 0.15698, 0.314, 0.60888,
+    0.80268, 0.99766, 0.93694, 0.89237,
+    0.13449, 0.27367, 0.53036, 0.18962,
+    0.57672, 0.48364, 0.10863, 0.0571
+]).reshape(1, 2, 4, 4)
+
+output = batch_norm(input)
+print_cpp_vector(output.flatten())
+
+print(batch_norm.running_mean)
+print(batch_norm.running_var)
+
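Both the kernel sequence in forward() (subtract mean, divide by standard deviation, scale, shift) and this script compute, per channel c, y = gamma_c * (x - mean_c) / sqrt(var_c + eps) + beta_c over the batch statistics. Note also that with track_running_stats=False the running_mean and running_var buffers are None, so the two trailing prints output None. For checking values by hand, a minimal CPU sketch of the same computation; the biased variance (divide by N) and eps = 1e-5f match torch.nn.BatchNorm2d defaults, and the function itself is illustrative, not part of CUDANet:

#include <cmath>
#include <vector>

// Per-channel batch norm over contiguous CHW data:
// y = gamma * (x - mean) / sqrt(var + eps) + beta.
std::vector<float> batchNormRef(const std::vector<float>& x,
                                const std::vector<float>& gamma,
                                const std::vector<float>& beta,
                                int channels, int spatial, float eps = 1e-5f) {
    std::vector<float> y(x.size());
    for (int c = 0; c < channels; ++c) {
        const float* in = x.data() + c * spatial;
        float mean = 0.0f;
        for (int i = 0; i < spatial; ++i) mean += in[i];
        mean /= spatial;
        float var = 0.0f;  // biased variance, as PyTorch uses when normalizing
        for (int i = 0; i < spatial; ++i) var += (in[i] - mean) * (in[i] - mean);
        var /= spatial;
        const float invStd = 1.0f / std::sqrt(var + eps);
        for (int i = 0; i < spatial; ++i)
            y[c * spatial + i] = gamma[c] * (in[i] - mean) * invStd + beta[c];
    }
    return y;
}

Called with channels = 2 and spatial = 16 on the input above, this should reproduce the expected vector in the C++ test to within the printed five-decimal precision.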