Add running mean and running var to batchnorm

This commit is contained in:
2024-08-25 19:05:10 +02:00
parent 1136ca452f
commit 9704d0d53e
8 changed files with 205 additions and 71 deletions

View File

@@ -140,6 +140,19 @@ __global__ void Kernels::vec_sqrt(
}
}
__global__ void Kernels::vec_scale(
const float* __restrict__ src,
float* __restrict__ dst,
const float* __restrict__ scale,
const float* epsilon,
const unsigned int len
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < len) {
float inv_std = rsqrtf(*scale + *epsilon);
dst[idx] = src[idx] * inv_std;
}
}
__global__ void Kernels::max_reduce(
const float* __restrict__ d_vector,

View File

@@ -26,19 +26,14 @@ BatchNorm2d::BatchNorm2d(
sizeof(float) * inputSize.first * inputSize.second * inputChannels
));
d_mean = nullptr;
d_running_mean = nullptr;
CUDA_CHECK(cudaMalloc(
(void **)&d_mean, sizeof(float) * inputSize.first * inputSize.second
(void **)&d_running_mean, sizeof(float) * inputChannels
));
d_mean_sub = nullptr;
d_running_var = nullptr;
CUDA_CHECK(cudaMalloc(
(void **)&d_mean_sub, sizeof(float) * inputSize.first * inputSize.second
));
d_sqrt_var = nullptr;
CUDA_CHECK(cudaMalloc(
(void **)&d_sqrt_var, sizeof(float) * inputSize.first * inputSize.second
(void **)&d_running_var, sizeof(float) * inputChannels
));
d_weights = nullptr;
@@ -63,8 +58,13 @@ BatchNorm2d::BatchNorm2d(
weights.resize(inputChannels);
biases.resize(inputChannels);
running_mean.resize(inputChannels);
running_var.resize(inputChannels);
initializeWeights();
initializeBiases();
initializeRunningMean();
initializeRunningVar();
toCuda();
@@ -74,9 +74,8 @@ BatchNorm2d::BatchNorm2d(
BatchNorm2d::~BatchNorm2d() {
cudaFree(d_output);
cudaFree(d_mean);
cudaFree(d_mean_sub);
cudaFree(d_sqrt_var);
cudaFree(d_running_mean);
cudaFree(d_running_var);
cudaFree(d_weights);
cudaFree(d_biases);
cudaFree(d_length);
@@ -91,6 +90,14 @@ void BatchNorm2d::initializeBiases() {
std::fill(biases.begin(), biases.end(), 0.0f);
}
void BatchNorm2d::initializeRunningMean() {
std::fill(running_mean.begin(), running_mean.end(), 0.0f);
}
void BatchNorm2d::initializeRunningVar() {
std::fill(running_var.begin(), running_var.end(), 1.0f);
}
void BatchNorm2d::setWeights(const float *weights_input) {
std::copy(weights_input, weights_input + weights.size(), weights.begin());
toCuda();
@@ -109,6 +116,16 @@ std::vector<float> BatchNorm2d::getBiases() {
return biases;
}
void BatchNorm2d::setRunningMean(const float* running_mean_input) {
std::copy(running_mean_input, running_mean_input + inputChannels, running_mean.begin());
toCuda();
}
void BatchNorm2d::setRunningVar(const float* running_var_input) {
std::copy(running_var_input, running_var_input + inputChannels, running_var.begin());
toCuda();
}
void BatchNorm2d::toCuda() {
CUDA_CHECK(cudaMemcpy(
d_weights, weights.data(), sizeof(float) * inputChannels,
@@ -118,6 +135,14 @@ void BatchNorm2d::toCuda() {
d_biases, biases.data(), sizeof(float) * inputChannels,
cudaMemcpyHostToDevice
));
CUDA_CHECK(cudaMemcpy(
d_running_mean, running_mean.data(), sizeof(float) * inputChannels,
cudaMemcpyHostToDevice
));
CUDA_CHECK(cudaMemcpy(
d_running_var, running_var.data(), sizeof(float) * inputChannels,
cudaMemcpyHostToDevice
));
}
int BatchNorm2d::getInputSize() {
@@ -135,48 +160,30 @@ shape2d BatchNorm2d::getOutputDims() {
float *BatchNorm2d::forward(const float *d_input) {
// Compute per-channel batch normalization
for (int i = 0; i < inputChannels; i++) {
// Compute mean
Utils::mean(
d_input + i * inputSize.first * inputSize.second, d_mean, d_length,
inputSize.first * inputSize.second
);
// Subtract mean from input
Kernels::vec_scalar_sub<<<gridSize, BLOCK_SIZE>>>(
d_input + i * inputSize.first * inputSize.second, d_mean_sub,
&d_mean[0], inputSize.first * inputSize.second
d_input + i * inputSize.first * inputSize.second,
d_output + i * inputSize.first * inputSize.second,
&d_running_mean[i], inputSize.first * inputSize.second
);
CUDA_CHECK(cudaGetLastError());
// Compute variance
Utils::var(
d_mean_sub, d_sqrt_var, d_length, inputSize.first * inputSize.second
);
// Add epsilon to variance to avoid division by zero
Kernels::vec_scalar_add<<<gridSize, BLOCK_SIZE>>>(
d_sqrt_var, d_sqrt_var, &d_epsilon[0],
// Divide by sqrt(running_var + epsilon)
Kernels::vec_scale<<<gridSize, BLOCK_SIZE>>>(
d_output + i * inputSize.first * inputSize.second,
d_output + i * inputSize.first * inputSize.second,
&d_running_var[i],
d_epsilon,
inputSize.first * inputSize.second
);
CUDA_CHECK(cudaGetLastError());
// Compute squared root of variance
Kernels::vec_sqrt<<<gridSize, BLOCK_SIZE>>>(
d_sqrt_var, d_sqrt_var, inputSize.first * inputSize.second
);
CUDA_CHECK(cudaGetLastError());
// Divide by squared root of variance
Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
d_mean_sub, d_output + i * inputSize.first * inputSize.second,
&d_sqrt_var[0], inputSize.first * inputSize.second
);
CUDA_CHECK(cudaGetLastError());
// Multiply by weights
Kernels::vec_scalar_mul<<<gridSize, BLOCK_SIZE>>>(
d_output + i * inputSize.first * inputSize.second,
d_output + i * inputSize.first * inputSize.second, &d_weights[i],
d_output + i * inputSize.first * inputSize.second,
&d_weights[i],
inputSize.first * inputSize.second
);
CUDA_CHECK(cudaGetLastError());
@@ -184,7 +191,8 @@ float *BatchNorm2d::forward(const float *d_input) {
// Add biases
Kernels::vec_scalar_add<<<gridSize, BLOCK_SIZE>>>(
d_output + i * inputSize.first * inputSize.second,
d_output + i * inputSize.first * inputSize.second, &d_biases[i],
d_output + i * inputSize.first * inputSize.second,
&d_biases[i],
inputSize.first * inputSize.second
);
CUDA_CHECK(cudaGetLastError());