Mirror of https://github.com/lordmathis/CUDANet.git
Add running mean and running var to batchnorm
@@ -140,6 +140,19 @@ __global__ void Kernels::vec_sqrt(
     }
 }
 
+__global__ void Kernels::vec_scale(
+    const float* __restrict__ src,
+    float* __restrict__ dst,
+    const float* __restrict__ scale,
+    const float* epsilon,
+    const unsigned int len
+) {
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx < len) {
+        float inv_std = rsqrtf(*scale + *epsilon);
+        dst[idx] = src[idx] * inv_std;
+    }
+}
+
 __global__ void Kernels::max_reduce(
     const float* __restrict__ d_vector,
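Note: vec_scale fuses what the old forward path did in three separate launches (add epsilon, vec_sqrt, then divide) into a single elementwise pass, and uses the rsqrtf intrinsic (a fast approximate reciprocal square root) instead of an exact sqrt followed by a division. A minimal host-side reference of the same math, useful for checking kernel output in tests (the helper name is hypothetical, not part of this repo):

    #include <cmath>

    // dst[i] = src[i] * 1/sqrt(scale + epsilon), the computation vec_scale
    // performs on device; rsqrtf may differ from this by a few ULPs.
    void vec_scale_ref(const float* src, float* dst, float scale,
                       float epsilon, unsigned int len) {
        float inv_std = 1.0f / std::sqrt(scale + epsilon);
        for (unsigned int i = 0; i < len; ++i) {
            dst[i] = src[i] * inv_std;
        }
    }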
@@ -26,19 +26,14 @@ BatchNorm2d::BatchNorm2d(
         sizeof(float) * inputSize.first * inputSize.second * inputChannels
     ));
 
-    d_mean = nullptr;
+    d_running_mean = nullptr;
     CUDA_CHECK(cudaMalloc(
-        (void **)&d_mean, sizeof(float) * inputSize.first * inputSize.second
+        (void **)&d_running_mean, sizeof(float) * inputChannels
     ));
 
-    d_mean_sub = nullptr;
+    d_running_var = nullptr;
     CUDA_CHECK(cudaMalloc(
-        (void **)&d_mean_sub, sizeof(float) * inputSize.first * inputSize.second
-    ));
-
-    d_sqrt_var = nullptr;
-    CUDA_CHECK(cudaMalloc(
-        (void **)&d_sqrt_var, sizeof(float) * inputSize.first * inputSize.second
+        (void **)&d_running_var, sizeof(float) * inputChannels
     ));
 
     d_weights = nullptr;
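Note the size change: the removed scratch buffers (d_mean, d_mean_sub, d_sqrt_var) each held one float per spatial position (inputSize.first * inputSize.second), while the new running-stat buffers hold one float per channel (inputChannels). For example, at 112x112 spatial size and 64 channels, that is 112 * 112 * 4 bytes (about 49 KB) per old buffer versus 64 * 4 = 256 bytes per new one.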
@@ -63,8 +58,13 @@ BatchNorm2d::BatchNorm2d(
     weights.resize(inputChannels);
     biases.resize(inputChannels);
 
+    running_mean.resize(inputChannels);
+    running_var.resize(inputChannels);
+
     initializeWeights();
     initializeBiases();
+    initializeRunningMean();
+    initializeRunningVar();
 
     toCuda();
 
@@ -74,9 +74,8 @@ BatchNorm2d::BatchNorm2d(
 
 BatchNorm2d::~BatchNorm2d() {
     cudaFree(d_output);
-    cudaFree(d_mean);
-    cudaFree(d_mean_sub);
-    cudaFree(d_sqrt_var);
+    cudaFree(d_running_mean);
+    cudaFree(d_running_var);
     cudaFree(d_weights);
     cudaFree(d_biases);
     cudaFree(d_length);
@@ -91,6 +90,14 @@ void BatchNorm2d::initializeBiases() {
     std::fill(biases.begin(), biases.end(), 0.0f);
 }
 
+void BatchNorm2d::initializeRunningMean() {
+    std::fill(running_mean.begin(), running_mean.end(), 0.0f);
+}
+
+void BatchNorm2d::initializeRunningVar() {
+    std::fill(running_var.begin(), running_var.end(), 1.0f);
+}
+
 void BatchNorm2d::setWeights(const float *weights_input) {
     std::copy(weights_input, weights_input + weights.size(), weights.begin());
     toCuda();
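Note: zero mean and unit variance are the standard defaults for running statistics (PyTorch's BatchNorm2d uses the same), and they make a freshly constructed layer close to an identity map. Assuming weights are initialized to one, as is conventional for batch norm, the forward pass reduces to

    y = 1 * (x - 0) / sqrt(1 + epsilon) + 0 ≈ x    (for small epsilon)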
@@ -109,6 +116,16 @@ std::vector<float> BatchNorm2d::getBiases() {
     return biases;
 }
 
+void BatchNorm2d::setRunningMean(const float* running_mean_input) {
+    std::copy(running_mean_input, running_mean_input + inputChannels, running_mean.begin());
+    toCuda();
+}
+
+void BatchNorm2d::setRunningVar(const float* running_var_input) {
+    std::copy(running_var_input, running_var_input + inputChannels, running_var.begin());
+    toCuda();
+}
+
 void BatchNorm2d::toCuda() {
     CUDA_CHECK(cudaMemcpy(
         d_weights, weights.data(), sizeof(float) * inputChannels,
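A short usage sketch for loading pretrained statistics through the new setters (the layer `bn`, the channel count `C`, and the data vectors are hypothetical; in practice they would come from a weight file):

    #include <vector>

    // Per-channel statistics exported from a trained model (hypothetical data).
    std::vector<float> mean_data(C);  // filled from a checkpoint
    std::vector<float> var_data(C);   // filled from a checkpoint
    bn.setRunningMean(mean_data.data());  // copies C floats into the host vector
    bn.setRunningVar(var_data.data());    // each setter also calls toCuda()

Each setter ends with toCuda(), which (after the change below) re-uploads weights, biases, and both running statistics, so calling both setters uploads everything twice; harmless, but worth knowing.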
@@ -118,6 +135,14 @@ void BatchNorm2d::toCuda() {
         d_biases, biases.data(), sizeof(float) * inputChannels,
         cudaMemcpyHostToDevice
     ));
+    CUDA_CHECK(cudaMemcpy(
+        d_running_mean, running_mean.data(), sizeof(float) * inputChannels,
+        cudaMemcpyHostToDevice
+    ));
+    CUDA_CHECK(cudaMemcpy(
+        d_running_var, running_var.data(), sizeof(float) * inputChannels,
+        cudaMemcpyHostToDevice
+    ));
 }
 
 int BatchNorm2d::getInputSize() {
@@ -135,48 +160,30 @@ shape2d BatchNorm2d::getOutputDims() {
 float *BatchNorm2d::forward(const float *d_input) {
     // Compute per-channel batch normalization
     for (int i = 0; i < inputChannels; i++) {
-        // Compute mean
-        Utils::mean(
-            d_input + i * inputSize.first * inputSize.second, d_mean, d_length,
-            inputSize.first * inputSize.second
-        );
-
         // Subtract mean from input
         Kernels::vec_scalar_sub<<<gridSize, BLOCK_SIZE>>>(
-            d_input + i * inputSize.first * inputSize.second, d_mean_sub,
-            &d_mean[0], inputSize.first * inputSize.second
+            d_input + i * inputSize.first * inputSize.second,
+            d_output + i * inputSize.first * inputSize.second,
+            &d_running_mean[i], inputSize.first * inputSize.second
         );
         CUDA_CHECK(cudaGetLastError());
 
-        // Compute variance
-        Utils::var(
-            d_mean_sub, d_sqrt_var, d_length, inputSize.first * inputSize.second
-        );
-
-        // Add epsilon to variance to avoid division by zero
-        Kernels::vec_scalar_add<<<gridSize, BLOCK_SIZE>>>(
-            d_sqrt_var, d_sqrt_var, &d_epsilon[0],
+        // Divide by sqrt(running_var + epsilon)
+        Kernels::vec_scale<<<gridSize, BLOCK_SIZE>>>(
+            d_output + i * inputSize.first * inputSize.second,
+            d_output + i * inputSize.first * inputSize.second,
+            &d_running_var[i],
+            d_epsilon,
             inputSize.first * inputSize.second
         );
         CUDA_CHECK(cudaGetLastError());
 
-        // Compute squared root of variance
-        Kernels::vec_sqrt<<<gridSize, BLOCK_SIZE>>>(
-            d_sqrt_var, d_sqrt_var, inputSize.first * inputSize.second
-        );
-        CUDA_CHECK(cudaGetLastError());
-
-        // Divide by squared root of variance
-        Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
-            d_mean_sub, d_output + i * inputSize.first * inputSize.second,
-            &d_sqrt_var[0], inputSize.first * inputSize.second
-        );
-        CUDA_CHECK(cudaGetLastError());
-
         // Multiply by weights
         Kernels::vec_scalar_mul<<<gridSize, BLOCK_SIZE>>>(
             d_output + i * inputSize.first * inputSize.second,
-            d_output + i * inputSize.first * inputSize.second, &d_weights[i],
+            d_output + i * inputSize.first * inputSize.second,
+            &d_weights[i],
             inputSize.first * inputSize.second
         );
         CUDA_CHECK(cudaGetLastError());
@@ -184,7 +191,8 @@ float *BatchNorm2d::forward(const float *d_input) {
         // Add biases
         Kernels::vec_scalar_add<<<gridSize, BLOCK_SIZE>>>(
             d_output + i * inputSize.first * inputSize.second,
-            d_output + i * inputSize.first * inputSize.second, &d_biases[i],
+            d_output + i * inputSize.first * inputSize.second,
+            &d_biases[i],
             inputSize.first * inputSize.second
         );
         CUDA_CHECK(cudaGetLastError());
 
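Taken together, forward() now implements standard inference-mode batch normalization. Per channel i and element x of that channel's plane:

    y = weights[i] * (x - running_mean[i]) / sqrt(running_var[i] + epsilon) + biases[i]

realized as four launches per channel: vec_scalar_sub (subtract the running mean), vec_scale (multiply by rsqrtf(running_var + epsilon)), vec_scalar_mul (apply the learned scale), and vec_scalar_add (add the learned bias). After this change, forward() always normalizes with the stored running statistics rather than batch statistics, i.e. it behaves as an eval-mode pass.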