mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-06 01:34:22 +00:00
Add running mean and running var to batchnorm
This commit is contained in:
@@ -66,9 +66,7 @@ TEST(MatMulTest, MatVecMulTest) {
|
||||
|
||||
cudaFree(d_matrix);
|
||||
cudaFree(d_vector);
|
||||
cudaFree(d_output);
|
||||
|
||||
|
||||
cudaFree(d_output);
|
||||
}
|
||||
|
||||
TEST(MatMulTest, MaxReduceTest) {
|
||||
@@ -211,4 +209,60 @@ TEST(MatMulTest, SumReduceTest) {
|
||||
|
||||
cudaFree(d_input);
|
||||
cudaFree(d_sum);
|
||||
}
|
||||
|
||||
TEST(MatMulTest, VecScaleTest) {
    cudaError_t cudaStatus;

    // Element count and device buffers. The kernel scales src element-wise
    // into dst; scale and epsilon are single floats read from device memory,
    // so each gets its own one-float allocation.
    int len = 1000;
    float* d_src;
    float* d_dst;
    float* d_scale;
    float* d_epsilon;

    cudaStatus = cudaMalloc((void**)&d_src, sizeof(float) * len);
    EXPECT_EQ(cudaStatus, cudaSuccess);

    cudaStatus = cudaMalloc((void**)&d_dst, sizeof(float) * len);
    EXPECT_EQ(cudaStatus, cudaSuccess);

    cudaStatus = cudaMalloc((void**)&d_scale, sizeof(float));
    EXPECT_EQ(cudaStatus, cudaSuccess);

    cudaStatus = cudaMalloc((void**)&d_epsilon, sizeof(float));
    EXPECT_EQ(cudaStatus, cudaSuccess);

    // Random host input in [0, 1]. rand() is intentionally unseeded here to
    // match the rest of this test file; tolerance below absorbs float noise.
    std::vector<float> src(len);
    for (int i = 0; i < len; ++i) {
        src[i] = static_cast<float>(rand()) / RAND_MAX;
    }

    float scale   = 1.5f;
    float epsilon = 1e-5f;

    cudaStatus = cudaMemcpy(d_src, src.data(), sizeof(float) * len, cudaMemcpyHostToDevice);
    EXPECT_EQ(cudaStatus, cudaSuccess);
    cudaStatus = cudaMemcpy(d_scale, &scale, sizeof(float), cudaMemcpyHostToDevice);
    EXPECT_EQ(cudaStatus, cudaSuccess);
    cudaStatus = cudaMemcpy(d_epsilon, &epsilon, sizeof(float), cudaMemcpyHostToDevice);
    EXPECT_EQ(cudaStatus, cudaSuccess);

    // Ceil-div launch so the last (partial) block still covers the tail.
    int grid_size = (len + BLOCK_SIZE - 1) / BLOCK_SIZE;
    CUDANet::Kernels::vec_scale<<<grid_size, BLOCK_SIZE>>>(d_src, d_dst, d_scale, d_epsilon, len);

    // FIX: kernel launches do not report launch-configuration errors through
    // any return value; they must be fetched with cudaGetLastError() before
    // the synchronize, which only surfaces async execution errors.
    cudaStatus = cudaGetLastError();
    EXPECT_EQ(cudaStatus, cudaSuccess);

    cudaStatus = cudaDeviceSynchronize();
    EXPECT_EQ(cudaStatus, cudaSuccess);

    std::vector<float> dst_gpu(len);
    cudaStatus = cudaMemcpy(dst_gpu.data(), d_dst, sizeof(float) * len, cudaMemcpyDeviceToHost);
    EXPECT_EQ(cudaStatus, cudaSuccess);

    // Host reference: the kernel is expected to compute
    //   dst[i] = src[i] / sqrt(scale + epsilon)
    // (this is the inverse-stddev scaling used by batchnorm inference).
    float inv_std = 1.0f / std::sqrt(scale + epsilon);
    for (int i = 0; i < len; ++i) {
        EXPECT_NEAR(src[i] * inv_std, dst_gpu[i], 1e-5f);
    }

    cudaFree(d_src);
    cudaFree(d_dst);
    cudaFree(d_scale);
    cudaFree(d_epsilon);
}
|
||||
@@ -8,10 +8,14 @@
|
||||
|
||||
class BatchNormLayerTest : public ::testing::Test {
|
||||
protected:
|
||||
shape2d inputSize;
|
||||
shape2d inputSize;
|
||||
int nChannels;
|
||||
std::vector<float> weights;
|
||||
std::vector<float> biases;
|
||||
|
||||
std::vector<float> runningMean;
|
||||
std::vector<float> runningVar;
|
||||
|
||||
std::vector<float> input;
|
||||
std::vector<float> expected;
|
||||
|
||||
@@ -41,6 +45,9 @@ class BatchNormLayerTest : public ::testing::Test {
|
||||
batchNorm->setWeights(weights.data());
|
||||
batchNorm->setBiases(biases.data());
|
||||
|
||||
batchNorm->setRunningMean(runningMean.data());
|
||||
batchNorm->setRunningVar(runningVar.data());
|
||||
|
||||
cudaStatus = cudaGetLastError();
|
||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
||||
|
||||
@@ -78,6 +85,9 @@ TEST_F(BatchNormLayerTest, BatchNormSmallForwardTest) {
|
||||
weights = {0.63508f, 0.64903f};
|
||||
biases = {0.25079f, 0.66841f};
|
||||
|
||||
runningMean = {0.5f, 0.5f};
|
||||
runningVar = {1.0f, 1.0f};
|
||||
|
||||
// clang-format off
|
||||
input = {
|
||||
// Channel 0
|
||||
@@ -93,12 +103,12 @@ TEST_F(BatchNormLayerTest, BatchNormSmallForwardTest) {
|
||||
};
|
||||
// clang-format on
|
||||
|
||||
expected = {-0.06007f, 0.951f, 0.18157f, 1.36202f, 0.39244f, 0.47335f,
|
||||
0.58598f, -1.00188f, 0.59576f, 0.79919f, -0.57001f, 0.70469f,
|
||||
-0.62847f, -0.06578f, -0.43668f, 0.72952f, 0.37726f, 0.02088f,
|
||||
0.35446f, 0.98092f, 1.39264f, 1.80686f, 1.67786f, 1.58318f,
|
||||
-0.0269f, 0.26878f, 0.81411f, 0.09022f, 0.9126f, 0.71485f,
|
||||
-0.08184f, -0.19131f};
|
||||
expected = {0.18029f, 0.44435f, 0.2434f, 0.5517f, 0.29847f, 0.3196f,
|
||||
0.34902f, -0.06568f, 0.35157f, 0.4047f, 0.04711f, 0.38002f,
|
||||
0.03184f, 0.1788f, 0.08193f, 0.38651f, 0.55466f, 0.44578f,
|
||||
0.54769f, 0.73908f, 0.86486f, 0.9914f, 0.952f, 0.92307f,
|
||||
0.43118f, 0.52152f, 0.68811f, 0.46697f, 0.7182f, 0.65779f,
|
||||
0.4144f, 0.38096f};
|
||||
|
||||
runTest();
|
||||
}
|
||||
@@ -109,6 +119,9 @@ TEST_F(BatchNormLayerTest, BatchNormNonSquareInputTest) {
|
||||
weights = {0.63508f, 0.64903f};
|
||||
biases = {0.25079f, 0.66841f};
|
||||
|
||||
runningMean = {0.5f, 0.5f};
|
||||
runningVar = {1.0f, 1.0f};
|
||||
|
||||
input = {// Channel 0
|
||||
0.38899f, 0.80478f, 0.48836f, 0.97381f, 0.21567f, 0.92312f,
|
||||
0.57508f, 0.60835f, 0.65467f, 0.00168f, 0.31567f, 0.71345f,
|
||||
@@ -121,16 +134,14 @@ TEST_F(BatchNormLayerTest, BatchNormNonSquareInputTest) {
|
||||
0.48364f, 0.10863f, 0.0571f, 0.78934f, 0.67545f
|
||||
};
|
||||
|
||||
expected = {-0.05598f, 0.87495f, 0.1665f, 1.2534f, -0.44404f,
|
||||
1.13991f, 0.36066f, 0.43515f, 0.53886f, -0.92315f,
|
||||
-0.22014f, 0.67047f, 0.54786f, 0.73517f, -0.52552f,
|
||||
0.64817f, -0.63907f, 1.21453f, -0.57934f, -0.06124f,
|
||||
-0.40275f, 0.67103f, -0.32712f, 0.94064f, 0.28344f,
|
||||
-0.08405f, 0.25993f, 0.90592f, 0.07909f, 1.30149f,
|
||||
1.33047f, 1.7576f, 1.62459f, 1.52695f, 0.9135f,
|
||||
1.59436f, -0.13331f, 0.17158f, 0.73391f, -0.01254f,
|
||||
0.57151f, -0.10979f, 0.83546f, 0.63156f, -0.18996f,
|
||||
-0.30285f, 1.30124f, 1.05175f};
|
||||
expected = {0.18029f, 0.44435f, 0.2434f, 0.5517f, 0.07022f, 0.5195f,
|
||||
0.29847f, 0.3196f, 0.34902f, -0.06568f, 0.13373f, 0.38635f,
|
||||
0.35157f, 0.4047f, 0.04711f, 0.38002f, 0.0149f, 0.54067f,
|
||||
0.03184f, 0.1788f, 0.08193f, 0.38651f, 0.10338f, 0.46298f,
|
||||
0.55466f, 0.44578f, 0.54769f, 0.73908f, 0.49411f, 0.85627f,
|
||||
0.86486f, 0.9914f, 0.952f, 0.92307f, 0.74132f, 0.94304f,
|
||||
0.43118f, 0.52152f, 0.68811f, 0.46697f, 0.64f, 0.43815f,
|
||||
0.7182f, 0.65779f, 0.4144f, 0.38096f, 0.8562f, 0.78228f};
|
||||
|
||||
runTest();
|
||||
}
|
||||
Reference in New Issue
Block a user