Add running mean and running var to batchnorm

This commit switches BatchNorm2d from computing per-batch statistics in the
forward pass to normalizing with stored running statistics (inference
behaviour): per channel, y = (x - running_mean) / sqrt(running_var + epsilon)
* weight + bias. It adds a vec_scale kernel that divides a vector by
sqrt(scale + epsilon), replaces the d_mean / d_mean_sub / d_sqrt_var buffers
with per-channel d_running_mean / d_running_var, exposes setRunningMean /
setRunningVar, and updates the tests and Python scripts to match. The hunks
below touch, in order: the kernel declarations header, the BatchNorm2d layer
header, the kernel implementations, the BatchNorm2d implementation, the
matmul kernel tests, the batchnorm layer tests, and the Python test-vector
and weight-export scripts.
@@ -141,7 +141,6 @@ __global__ void vec_exp(
  * @param src Device pointer to source vector
  * @param dst Device pointer to destination vector
  * @param len Length of the vector
- * @return __global__
  */
 __global__ void vec_sqrt(
     const float* __restrict__ src,
@@ -149,6 +148,23 @@ __global__ void vec_sqrt(
     const unsigned int len
 );

+/**
+ * @brief Scales the vector by 1 / sqrt(scale + epsilon)
+ *
+ * @param src Device pointer to source vector
+ * @param dst Device pointer to destination vector
+ * @param scale Device pointer to the scale value (e.g. a variance)
+ * @param epsilon Device pointer to the epsilon value
+ * @param len Length of the vector
+ */
+__global__ void vec_scale(
+    const float* __restrict__ src,
+    float* __restrict__ dst,
+    const float* __restrict__ scale,
+    const float* epsilon,
+    const unsigned int len
+);
+
 /**
  * @brief Max reduction kernel
  *
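For every element, the new kernel computes dst[i] = src[i] * rsqrtf(*scale + *epsilon), i.e. it divides the vector by sqrt(scale + epsilon). A minimal launch sketch, mirroring the unit test added further down (d_src, d_dst, d_scale and d_epsilon are assumed to be existing device allocations):

    // One thread per element; scale and epsilon live in device memory.
    int gridSize = (len + BLOCK_SIZE - 1) / BLOCK_SIZE;
    CUDANet::Kernels::vec_scale<<<gridSize, BLOCK_SIZE>>>(
        d_src, d_dst, d_scale, d_epsilon, len
    );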
@@ -50,6 +50,20 @@ class BatchNorm2d : public WeightedLayer, public TwoDLayer {
      */
     std::vector<float> getBiases();

+    /**
+     * @brief Set the running mean
+     *
+     * @param running_mean_input Pointer to the per-channel running mean values
+     */
+    void setRunningMean(const float* running_mean_input);
+
+    /**
+     * @brief Set the running variance
+     *
+     * @param running_var_input Pointer to the per-channel running variance values
+     */
+    void setRunningVar(const float* running_var_input);
+
     /**
      * @brief Get output size
      *
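Both setters copy inputChannels floats and immediately re-upload the host-side vectors via toCuda(). A typical call site before inference, following the updated test fixture below (the variable names here are illustrative):

    std::vector<float> mean(nChannels, 0.0f);
    std::vector<float> var(nChannels, 1.0f);
    batchNorm->setRunningMean(mean.data());
    batchNorm->setRunningVar(var.data());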
@@ -75,9 +89,8 @@ class BatchNorm2d : public WeightedLayer, public TwoDLayer {

     float* d_output;

-    float* d_mean;
-    float* d_mean_sub;
-    float* d_sqrt_var;
+    float* d_running_mean;
+    float* d_running_var;

     float* d_length;
     float* d_epsilon;
@@ -88,8 +101,8 @@ class BatchNorm2d : public WeightedLayer, public TwoDLayer {
     std::vector<float> weights;
     std::vector<float> biases;

-    std::vector<float> mean;
-    std::vector<float> sqrt_var;
+    std::vector<float> running_mean;
+    std::vector<float> running_var;

     Activation* activation;

@@ -109,13 +122,13 @@ class BatchNorm2d : public WeightedLayer, public TwoDLayer {
      * @brief Initialize mean of the batchnorm layer with zeros
      *
      */
-    void initializeMean();
+    void initializeRunningMean();

     /**
      * @brief Initialize sqrt of variance of the batchnorm layer with ones
      *
      */
-    void initializeSqrtVar();
+    void initializeRunningVar();

     /**
      * @brief Copy weights and biases to the device
@@ -140,6 +140,19 @@ __global__ void Kernels::vec_sqrt(
     }
 }

+__global__ void Kernels::vec_scale(
+    const float* __restrict__ src,
+    float* __restrict__ dst,
+    const float* __restrict__ scale,
+    const float* epsilon,
+    const unsigned int len
+) {
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx < len) {
+        float inv_std = rsqrtf(*scale + *epsilon);
+        dst[idx] = src[idx] * inv_std;
+    }
+}

 __global__ void Kernels::max_reduce(
     const float* __restrict__ d_vector,
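Every thread dereferences the same scale and epsilon pointers and recomputes rsqrtf; that is redundant but typically cheap, since the loads broadcast and rsqrtf is a single hardware instruction. A possible alternative (an assumption, not what this commit does) would precompute the inverse standard deviation once and reuse the existing scalar-multiply kernel:

    // Host side: compute 1 / sqrt(var + eps) once and upload it,
    // then scale with the already-existing vec_scalar_mul kernel.
    float inv_std = 1.0f / std::sqrt(var + eps);
    CUDA_CHECK(cudaMemcpy(d_inv_std, &inv_std, sizeof(float), cudaMemcpyHostToDevice));
    Kernels::vec_scalar_mul<<<gridSize, BLOCK_SIZE>>>(d_src, d_dst, d_inv_std, len);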
@@ -26,19 +26,14 @@ BatchNorm2d::BatchNorm2d(
         sizeof(float) * inputSize.first * inputSize.second * inputChannels
     ));

-    d_mean = nullptr;
+    d_running_mean = nullptr;
     CUDA_CHECK(cudaMalloc(
-        (void **)&d_mean, sizeof(float) * inputSize.first * inputSize.second
+        (void **)&d_running_mean, sizeof(float) * inputChannels
     ));

-    d_mean_sub = nullptr;
+    d_running_var = nullptr;
     CUDA_CHECK(cudaMalloc(
-        (void **)&d_mean_sub, sizeof(float) * inputSize.first * inputSize.second
+        (void **)&d_running_var, sizeof(float) * inputChannels
     ));
-
-    d_sqrt_var = nullptr;
-    CUDA_CHECK(cudaMalloc(
-        (void **)&d_sqrt_var, sizeof(float) * inputSize.first * inputSize.second
-    ));

     d_weights = nullptr;
@@ -63,8 +58,13 @@ BatchNorm2d::BatchNorm2d(
     weights.resize(inputChannels);
     biases.resize(inputChannels);

+    running_mean.resize(inputChannels);
+    running_var.resize(inputChannels);
+
     initializeWeights();
     initializeBiases();
+    initializeRunningMean();
+    initializeRunningVar();

     toCuda();

@@ -74,9 +74,8 @@ BatchNorm2d::BatchNorm2d(

 BatchNorm2d::~BatchNorm2d() {
     cudaFree(d_output);
-    cudaFree(d_mean);
-    cudaFree(d_mean_sub);
-    cudaFree(d_sqrt_var);
+    cudaFree(d_running_mean);
+    cudaFree(d_running_var);
     cudaFree(d_weights);
     cudaFree(d_biases);
     cudaFree(d_length);
@@ -91,6 +90,14 @@ void BatchNorm2d::initializeBiases() {
     std::fill(biases.begin(), biases.end(), 0.0f);
 }

+void BatchNorm2d::initializeRunningMean() {
+    std::fill(running_mean.begin(), running_mean.end(), 0.0f);
+}
+
+void BatchNorm2d::initializeRunningVar() {
+    std::fill(running_var.begin(), running_var.end(), 1.0f);
+}
+
 void BatchNorm2d::setWeights(const float *weights_input) {
     std::copy(weights_input, weights_input + weights.size(), weights.begin());
     toCuda();
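Initializing the running mean to zeros and the running variance to ones makes an untrained layer compute y = weight * x / sqrt(1 + epsilon) + bias, i.e. essentially an identity normalization, which matches PyTorch's defaults for these buffers.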
@@ -109,6 +116,16 @@ std::vector<float> BatchNorm2d::getBiases() {
     return biases;
 }

+void BatchNorm2d::setRunningMean(const float* running_mean_input) {
+    std::copy(running_mean_input, running_mean_input + inputChannels, running_mean.begin());
+    toCuda();
+}
+
+void BatchNorm2d::setRunningVar(const float* running_var_input) {
+    std::copy(running_var_input, running_var_input + inputChannels, running_var.begin());
+    toCuda();
+}
+
 void BatchNorm2d::toCuda() {
     CUDA_CHECK(cudaMemcpy(
         d_weights, weights.data(), sizeof(float) * inputChannels,
@@ -118,6 +135,14 @@ void BatchNorm2d::toCuda() {
         d_biases, biases.data(), sizeof(float) * inputChannels,
         cudaMemcpyHostToDevice
     ));
+    CUDA_CHECK(cudaMemcpy(
+        d_running_mean, running_mean.data(), sizeof(float) * inputChannels,
+        cudaMemcpyHostToDevice
+    ));
+    CUDA_CHECK(cudaMemcpy(
+        d_running_var, running_var.data(), sizeof(float) * inputChannels,
+        cudaMemcpyHostToDevice
+    ));
 }

 int BatchNorm2d::getInputSize() {
@@ -135,48 +160,30 @@ shape2d BatchNorm2d::getOutputDims() {
 float *BatchNorm2d::forward(const float *d_input) {
     // Compute per-channel batch normalization
     for (int i = 0; i < inputChannels; i++) {
-        // Compute mean
-        Utils::mean(
-            d_input + i * inputSize.first * inputSize.second, d_mean, d_length,
-            inputSize.first * inputSize.second
-        );

         // Subtract mean from input
         Kernels::vec_scalar_sub<<<gridSize, BLOCK_SIZE>>>(
-            d_input + i * inputSize.first * inputSize.second, d_mean_sub,
-            &d_mean[0], inputSize.first * inputSize.second
+            d_input + i * inputSize.first * inputSize.second,
+            d_output + i * inputSize.first * inputSize.second,
+            &d_running_mean[i], inputSize.first * inputSize.second
         );
         CUDA_CHECK(cudaGetLastError());

-        // Compute variance
-        Utils::var(
-            d_mean_sub, d_sqrt_var, d_length, inputSize.first * inputSize.second
-        );
-
-        // Add epsilon to variance to avoid division by zero
-        Kernels::vec_scalar_add<<<gridSize, BLOCK_SIZE>>>(
-            d_sqrt_var, d_sqrt_var, &d_epsilon[0],
+        // Divide by sqrt(running_var + epsilon)
+        Kernels::vec_scale<<<gridSize, BLOCK_SIZE>>>(
+            d_output + i * inputSize.first * inputSize.second,
+            d_output + i * inputSize.first * inputSize.second,
+            &d_running_var[i],
+            d_epsilon,
             inputSize.first * inputSize.second
         );
         CUDA_CHECK(cudaGetLastError());

-        // Compute squared root of variance
-        Kernels::vec_sqrt<<<gridSize, BLOCK_SIZE>>>(
-            d_sqrt_var, d_sqrt_var, inputSize.first * inputSize.second
-        );
-        CUDA_CHECK(cudaGetLastError());
-
-        // Divide by squared root of variance
-        Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
-            d_mean_sub, d_output + i * inputSize.first * inputSize.second,
-            &d_sqrt_var[0], inputSize.first * inputSize.second
-        );
-        CUDA_CHECK(cudaGetLastError());
-
         // Multiply by weights
         Kernels::vec_scalar_mul<<<gridSize, BLOCK_SIZE>>>(
             d_output + i * inputSize.first * inputSize.second,
-            d_output + i * inputSize.first * inputSize.second, &d_weights[i],
+            d_output + i * inputSize.first * inputSize.second,
+            &d_weights[i],
             inputSize.first * inputSize.second
         );
         CUDA_CHECK(cudaGetLastError());
@@ -184,7 +191,8 @@ float *BatchNorm2d::forward(const float *d_input) {
         // Add biases
         Kernels::vec_scalar_add<<<gridSize, BLOCK_SIZE>>>(
             d_output + i * inputSize.first * inputSize.second,
-            d_output + i * inputSize.first * inputSize.second, &d_biases[i],
+            d_output + i * inputSize.first * inputSize.second,
+            &d_biases[i],
             inputSize.first * inputSize.second
         );
         CUDA_CHECK(cudaGetLastError());
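With these two hunks the forward pass applies, per channel i, the standard inference-time batchnorm transform y = (x - running_mean[i]) / sqrt(running_var[i] + epsilon) * weights[i] + biases[i], realized as vec_scalar_sub, vec_scale, vec_scalar_mul and vec_scalar_add over the H*W plane, working in place in d_output. A host-side reference of the same computation (an illustrative sketch, not code from this commit; gamma/beta stand for the layer's weights/biases):

    for (int c = 0; c < channels; ++c) {
        // One scalar per channel; eps guards against division by zero.
        float inv_std = 1.0f / std::sqrt(running_var[c] + eps);
        for (int j = 0; j < h * w; ++j) {
            int k = c * h * w + j;
            out[k] = (in[k] - running_mean[c]) * inv_std * gamma[c] + beta[c];
        }
    }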
@@ -66,9 +66,7 @@ TEST(MatMulTest, MatVecMulTest) {

     cudaFree(d_matrix);
     cudaFree(d_vector);
     cudaFree(d_output);
-
-
 }

 TEST(MatMulTest, MaxReduceTest) {
@@ -211,4 +209,60 @@ TEST(MatMulTest, SumReduceTest) {

     cudaFree(d_input);
     cudaFree(d_sum);
+}
+
+TEST(MatMulTest, VecScaleTest) {
+    cudaError_t cudaStatus;
+    int len = 1000;
+    float* d_src;
+    float* d_dst;
+    float* d_scale;
+    float* d_epsilon;
+
+    cudaStatus = cudaMalloc((void**)&d_src, sizeof(float) * len);
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    cudaStatus = cudaMalloc((void**)&d_dst, sizeof(float) * len);
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    cudaStatus = cudaMalloc((void**)&d_scale, sizeof(float));
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    cudaStatus = cudaMalloc((void**)&d_epsilon, sizeof(float));
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    std::vector<float> src(len);
+    for (int i = 0; i < len; ++i) {
+        src[i] = static_cast<float>(rand()) / RAND_MAX;
+    }
+
+    float scale = 1.5f;
+    float epsilon = 1e-5f;
+
+    cudaStatus = cudaMemcpy(d_src, src.data(), sizeof(float) * len, cudaMemcpyHostToDevice);
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+    cudaStatus = cudaMemcpy(d_scale, &scale, sizeof(float), cudaMemcpyHostToDevice);
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+    cudaStatus = cudaMemcpy(d_epsilon, &epsilon, sizeof(float), cudaMemcpyHostToDevice);
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    int grid_size = (len + BLOCK_SIZE - 1) / BLOCK_SIZE;
+    CUDANet::Kernels::vec_scale<<<grid_size, BLOCK_SIZE>>>(d_src, d_dst, d_scale, d_epsilon, len);
+
+    cudaStatus = cudaDeviceSynchronize();
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    std::vector<float> dst_gpu(len);
+    cudaStatus = cudaMemcpy(dst_gpu.data(), d_dst, sizeof(float) * len, cudaMemcpyDeviceToHost);
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    float inv_std = 1.0f / std::sqrt(scale + epsilon);
+    for (int i = 0; i < len; ++i) {
+        EXPECT_NEAR(src[i] * inv_std, dst_gpu[i], 1e-5f);
+    }
+
+    cudaFree(d_src);
+    cudaFree(d_dst);
+    cudaFree(d_scale);
+    cudaFree(d_epsilon);
 }
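With scale = 1.5 and epsilon = 1e-5, the kernel should return src[i] / sqrt(1.50001), i.e. roughly src[i] * 0.8165, which the EXPECT_NEAR loop checks against the CPU reference to within 1e-5.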
@@ -8,10 +8,14 @@

 class BatchNormLayerTest : public ::testing::Test {
   protected:
     shape2d inputSize;
     int nChannels;
     std::vector<float> weights;
     std::vector<float> biases;
+
+    std::vector<float> runningMean;
+    std::vector<float> runningVar;
+
     std::vector<float> input;
     std::vector<float> expected;

@@ -41,6 +45,9 @@ class BatchNormLayerTest : public ::testing::Test {
         batchNorm->setWeights(weights.data());
         batchNorm->setBiases(biases.data());

+        batchNorm->setRunningMean(runningMean.data());
+        batchNorm->setRunningVar(runningVar.data());
+
         cudaStatus = cudaGetLastError();
         EXPECT_EQ(cudaStatus, cudaSuccess);

@@ -78,6 +85,9 @@ TEST_F(BatchNormLayerTest, BatchNormSmallForwardTest) {
     weights = {0.63508f, 0.64903f};
     biases = {0.25079f, 0.66841f};

+    runningMean = {0.5f, 0.5f};
+    runningVar = {1.0f, 1.0f};
+
     // clang-format off
     input = {
         // Channel 0
@@ -93,12 +103,12 @@ TEST_F(BatchNormLayerTest, BatchNormSmallForwardTest) {
     };
     // clang-format on

-    expected = {-0.06007f, 0.951f, 0.18157f, 1.36202f, 0.39244f, 0.47335f,
-                0.58598f, -1.00188f, 0.59576f, 0.79919f, -0.57001f, 0.70469f,
-                -0.62847f, -0.06578f, -0.43668f, 0.72952f, 0.37726f, 0.02088f,
-                0.35446f, 0.98092f, 1.39264f, 1.80686f, 1.67786f, 1.58318f,
-                -0.0269f, 0.26878f, 0.81411f, 0.09022f, 0.9126f, 0.71485f,
-                -0.08184f, -0.19131f};
+    expected = {0.18029f, 0.44435f, 0.2434f, 0.5517f, 0.29847f, 0.3196f,
+                0.34902f, -0.06568f, 0.35157f, 0.4047f, 0.04711f, 0.38002f,
+                0.03184f, 0.1788f, 0.08193f, 0.38651f, 0.55466f, 0.44578f,
+                0.54769f, 0.73908f, 0.86486f, 0.9914f, 0.952f, 0.92307f,
+                0.43118f, 0.52152f, 0.68811f, 0.46697f, 0.7182f, 0.65779f,
+                0.4144f, 0.38096f};

     runTest();
 }
@@ -109,6 +119,9 @@ TEST_F(BatchNormLayerTest, BatchNormNonSquareInputTest) {
     weights = {0.63508f, 0.64903f};
     biases = {0.25079f, 0.66841f};

+    runningMean = {0.5f, 0.5f};
+    runningVar = {1.0f, 1.0f};
+
     input = {// Channel 0
              0.38899f, 0.80478f, 0.48836f, 0.97381f, 0.21567f, 0.92312f,
              0.57508f, 0.60835f, 0.65467f, 0.00168f, 0.31567f, 0.71345f,
@@ -121,16 +134,14 @@ TEST_F(BatchNormLayerTest, BatchNormNonSquareInputTest) {
              0.48364f, 0.10863f, 0.0571f, 0.78934f, 0.67545f
     };

-    expected = {-0.05598f, 0.87495f, 0.1665f, 1.2534f, -0.44404f,
-                1.13991f, 0.36066f, 0.43515f, 0.53886f, -0.92315f,
-                -0.22014f, 0.67047f, 0.54786f, 0.73517f, -0.52552f,
-                0.64817f, -0.63907f, 1.21453f, -0.57934f, -0.06124f,
-                -0.40275f, 0.67103f, -0.32712f, 0.94064f, 0.28344f,
-                -0.08405f, 0.25993f, 0.90592f, 0.07909f, 1.30149f,
-                1.33047f, 1.7576f, 1.62459f, 1.52695f, 0.9135f,
-                1.59436f, -0.13331f, 0.17158f, 0.73391f, -0.01254f,
-                0.57151f, -0.10979f, 0.83546f, 0.63156f, -0.18996f,
-                -0.30285f, 1.30124f, 1.05175f};
+    expected = {0.18029f, 0.44435f, 0.2434f, 0.5517f, 0.07022f, 0.5195f,
+                0.29847f, 0.3196f, 0.34902f, -0.06568f, 0.13373f, 0.38635f,
+                0.35157f, 0.4047f, 0.04711f, 0.38002f, 0.0149f, 0.54067f,
+                0.03184f, 0.1788f, 0.08193f, 0.38651f, 0.10338f, 0.46298f,
+                0.55466f, 0.44578f, 0.54769f, 0.73908f, 0.49411f, 0.85627f,
+                0.86486f, 0.9914f, 0.952f, 0.92307f, 0.74132f, 0.94304f,
+                0.43118f, 0.52152f, 0.68811f, 0.46697f, 0.64f, 0.43815f,
+                0.7182f, 0.65779f, 0.4144f, 0.38096f, 0.8562f, 0.78228f};

     runTest();
 }
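The expected vectors can be spot-checked by hand: for the first input value 0.38899 in channel 0, (0.38899 - 0.5) / sqrt(1.0 + 1e-5) * 0.63508 + 0.25079 ≈ -0.11101 * 0.63508 + 0.25079 ≈ 0.18029, which is the first expected entry.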
@@ -5,7 +5,7 @@ from utils import print_cpp_vector


 def gen_batch_norm_test_result(input):

-    batch_norm = torch.nn.BatchNorm2d(2, track_running_stats=False)
+    batch_norm = torch.nn.BatchNorm2d(2, track_running_stats=True)

     weights = torch.Tensor([0.63508, 0.64903])
     biases = torch.Tensor([0.25079, 0.66841])
@@ -13,7 +13,13 @@ def gen_batch_norm_test_result(input):
     batch_norm.weight = torch.nn.Parameter(weights)
     batch_norm.bias = torch.nn.Parameter(biases)

+    batch_norm.running_mean = torch.Tensor([0.5, 0.5])
+    batch_norm.running_var = torch.Tensor([1.0, 1.0])
+
+    batch_norm.eval()
+
     output = batch_norm(input)

     print_cpp_vector(output.flatten())
+

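The switch to track_running_stats=True together with the explicit batch_norm.eval() call is what makes PyTorch normalize with the assigned running_mean/running_var instead of per-batch statistics, so the printed vectors match what the CUDA forward pass above now computes.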
@@ -35,6 +35,17 @@ def export_model_weights(model: torch.nn.Module, filename):

             tensor_data += tensor_bytes

+        # Add buffers (for running_mean and running_var)
+        for name, buf in model.named_buffers():
+            if "running_mean" not in name and "running_var" not in name:
+                continue
+
+            tensor_bytes = buf.type(torch.float32).detach().numpy().tobytes()
+            tensor_size = buf.numel()
+            header += f"{name},{tensor_size},{offset}\n"
+            offset += len(tensor_bytes)
+            tensor_data += tensor_bytes
+
         f.seek(0)
         f.write(struct.pack("H", version))
         f.write(struct.pack("Q", len(header)))
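Buffer entries reuse the existing "{name},{tensor_size},{offset}" header format, so a batchnorm layer now contributes lines such as "batchnorm.running_mean,2,64" (a hypothetical name and offset) alongside its weight and bias entries, letting readers of the file locate the running statistics by name.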