mirror of https://github.com/lordmathis/CUDANet.git

Compute mean and variance
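This commit rewrites the BatchNorm layer so that forward() computes each channel's mean and variance from the current input on the GPU, rather than normalizing with precomputed per-channel statistics uploaded from the host. The host-side mean/sqrt_var vectors, their initializers, and their device uploads are removed, and new device buffers (d_mean_sub, d_length, d_epsilon) plus enlarged d_mean/d_sqrt_var buffers support the new per-channel reduction pipeline.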
@@ -5,6 +5,7 @@
 #include "cuda_helper.cuh"
 #include "layer.cuh"
 #include "matmul.cuh"
+#include "vector.cuh"

 using namespace CUDANet::Layers;

@@ -24,10 +25,15 @@ BatchNorm::BatchNorm(
     ));

     d_mean = nullptr;
-    CUDA_CHECK(cudaMalloc((void **)&d_mean, sizeof(float) * inputChannels));
+    CUDA_CHECK(cudaMalloc((void **)&d_mean, sizeof(float) * inputSize * inputSize));
+
+    d_mean_sub = nullptr;
+    CUDA_CHECK(
+        cudaMalloc((void **)&d_mean_sub, sizeof(float) * inputSize * inputSize)
+    );

     d_sqrt_var = nullptr;
-    CUDA_CHECK(cudaMalloc((void **)&d_sqrt_var, sizeof(float) * inputChannels));
+    CUDA_CHECK(cudaMalloc((void **)&d_sqrt_var, sizeof(float) * inputSize * inputSize));

     d_weights = nullptr;
     CUDA_CHECK(cudaMalloc((void **)&d_weights, sizeof(float) * inputChannels));
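Note the size change: d_mean and d_sqrt_var grow from inputChannels floats to inputSize * inputSize floats each. In the new forward() they appear to serve as per-channel reduction workspaces; after Utils::sum only element 0 holds the reduced scalar (the scalar kernels below read &d_mean[0] and &d_sqrt_var[0]), while d_mean_sub holds the full (input - mean) map for the channel being processed.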
@@ -35,28 +41,36 @@ BatchNorm::BatchNorm(
     d_biases = nullptr;
     CUDA_CHECK(cudaMalloc((void **)&d_biases, sizeof(float) * inputChannels));

+    d_length = nullptr;
+    CUDA_CHECK(cudaMalloc((void **)&d_length, sizeof(float)));
+    CUDA_CHECK(cudaMemset(d_length, inputSize * inputSize, sizeof(float)));
+
+    d_epsilon = nullptr;
+    float epsilon = 1e-5f;
+    CUDA_CHECK(cudaMalloc((void **)&d_epsilon, sizeof(float)));
+    CUDA_CHECK(cudaMemcpy(d_epsilon, &epsilon, sizeof(float), cudaMemcpyHostToDevice));
+
     weights.resize(inputChannels);
     biases.resize(inputChannels);
-    mean.resize(inputChannels);
-    sqrt_var.resize(inputChannels);

     initializeWeights();
     initializeBiases();
-    initializeMean();
-    initializeSqrtVar();

     toCuda();

     gridSize =
-        (inputSize * inputSize * inputChannels + BLOCK_SIZE - 1) / BLOCK_SIZE;
+        (inputSize * inputSize + BLOCK_SIZE - 1) / BLOCK_SIZE;
 }

 BatchNorm::~BatchNorm() {
     cudaFree(d_output);
     cudaFree(d_mean);
+    cudaFree(d_mean_sub);
     cudaFree(d_sqrt_var);
     cudaFree(d_weights);
     cudaFree(d_biases);
+    cudaFree(d_length);
+    cudaFree(d_epsilon);
 }

 void BatchNorm::initializeWeights() {
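One bug worth flagging in this hunk: cudaMemset fills memory with a repeated byte value (its int argument is truncated to unsigned char), so cudaMemset(d_length, inputSize * inputSize, sizeof(float)) does not store the float value inputSize * inputSize in d_length. A minimal sketch of the usual fix, assuming d_length is meant to hold the per-channel element count as a float for the vec_scalar_div calls in forward():

    // Store the reduction length (inputSize * inputSize) on the device as a
    // typed float. cudaMemset cannot do this: it writes a repeated byte
    // pattern, not a float value, so the memset above leaves d_length with
    // a garbage value rather than the element count.
    float length = static_cast<float>(inputSize * inputSize);
    CUDA_CHECK(cudaMemcpy(d_length, &length, sizeof(float), cudaMemcpyHostToDevice));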
@@ -67,14 +81,6 @@ void BatchNorm::initializeBiases() {
     std::fill(biases.begin(), biases.end(), 0.0f);
 }

-void BatchNorm::initializeMean() {
-    std::fill(mean.begin(), mean.end(), 0.0f);
-}
-
-void BatchNorm::initializeSqrtVar() {
-    std::fill(sqrt_var.begin(), sqrt_var.end(), 1.0f);
-}
-
 void BatchNorm::setWeights(const float *weights_input) {
     std::copy(weights_input, weights_input + weights.size(), weights.begin());
     toCuda();
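With initializeMean() and initializeSqrtVar() deleted here (and their calls removed from the constructor above), the layer no longer carries running statistics at all: it always normalizes with the statistics of the tensor currently passing through, i.e. training-mode batch-norm behavior even at inference time.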
@@ -102,14 +108,6 @@ void BatchNorm::toCuda() {
         d_biases, biases.data(), sizeof(float) * inputChannels,
         cudaMemcpyHostToDevice
     ));
-    CUDA_CHECK(cudaMemcpy(
-        d_mean, mean.data(), sizeof(float) * inputChannels,
-        cudaMemcpyHostToDevice
-    ));
-    CUDA_CHECK(cudaMemcpy(
-        d_sqrt_var, sqrt_var.data(), sizeof(float) * inputChannels,
-        cudaMemcpyHostToDevice
-    ));
 }

 int BatchNorm::getInputSize() {
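For reference, the forward() hunk below implements the standard batch-normalization statistics per channel c over its N = inputSize * inputSize spatial positions (the learned scale and shift by d_weights/d_biases are not applied in this hunk):

    \mu_c = \frac{1}{N} \sum_{j=1}^{N} x_{c,j}, \qquad
    \sigma_c^2 = \frac{1}{N} \sum_{j=1}^{N} \left( x_{c,j} - \mu_c \right)^2, \qquad
    \hat{x}_{c,j} = \frac{x_{c,j} - \mu_c}{\sqrt{\sigma_c^2 + \varepsilon}}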
@@ -122,19 +120,83 @@ int BatchNorm::getOutputSize() {

 float *BatchNorm::forward(const float *d_input) {

     // Compute per-channel batch normalization
     for (int i = 0; i < inputChannels; i++) {
-        Kernels::vec_scalar_sub<<<gridSize, BLOCK_SIZE>>>(
+
+        // Compute mean
+        // Sum over all values
+        Utils::sum(
             d_input + i * inputSize * inputSize,
-            d_output + i * inputSize * inputSize,
-            &d_mean[i],
+            d_mean,
             inputSize * inputSize
         );
+
+        // Divide sum by length -> mean
+        Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
+            d_mean,
+            d_mean,
+            d_length,
+            inputSize * inputSize
+        );
         CUDA_CHECK(cudaGetLastError());
+
+        // Subtract mean from input
+        Kernels::vec_scalar_sub<<<gridSize, BLOCK_SIZE>>>(
+            d_input + i * inputSize * inputSize,
+            d_mean_sub,
+            &d_mean[0],
+            inputSize * inputSize
+        );
+        CUDA_CHECK(cudaGetLastError());
+
+        // Compute variance
+        // Square the differences of input - mean
+        Kernels::vec_vec_mul<<<gridSize, BLOCK_SIZE>>>(
+            d_mean_sub,
+            d_mean_sub,
+            d_sqrt_var,
+            inputSize * inputSize
+        );
+        CUDA_CHECK(cudaGetLastError());
+
+        // Sum over all squared differences
+        Utils::sum(
+            d_sqrt_var,
+            d_sqrt_var,
+            inputSize * inputSize
+        );
+
+        // Divide squared-difference sum by length -> variance
+        Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
+            d_sqrt_var,
+            d_sqrt_var,
+            d_length,
+            inputSize * inputSize
+        );
+        CUDA_CHECK(cudaGetLastError());
+
+        // Add epsilon to variance to avoid division by zero
+        Kernels::vec_scalar_add<<<gridSize, BLOCK_SIZE>>>(
+            d_sqrt_var,
+            d_sqrt_var,
+            &d_epsilon[0],
+            inputSize * inputSize
+        );
+        CUDA_CHECK(cudaGetLastError());
+
+        // Compute square root of variance
+        Kernels::vec_sqrt<<<gridSize, BLOCK_SIZE>>>(
+            d_sqrt_var,
+            d_sqrt_var,
+            inputSize * inputSize
+        );
+        CUDA_CHECK(cudaGetLastError());
+
+        // Divide by square root of variance
         Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
+            d_mean_sub,
             d_output + i * inputSize * inputSize,
-            d_output + i * inputSize * inputSize,
-            &d_sqrt_var[i],
+            &d_sqrt_var[0],
             inputSize * inputSize
         );
         CUDA_CHECK(cudaGetLastError());
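The kernel pipeline above is easiest to sanity-check against a host-side reference. A minimal sketch (the function name batchnorm_reference and the flat channel-major layout are my assumptions, not CUDANet API), computing the same biased variance and putting the same epsilon inside the square root:

    #include <cmath>
    #include <vector>

    // Host-side reference for the per-channel pipeline above: for each
    // channel, compute the mean and (biased) variance over
    // inputSize * inputSize values, then normalize by sqrt(var + eps).
    // output must be pre-sized to input.size().
    void batchnorm_reference(const std::vector<float> &input,
                             std::vector<float>       &output,
                             int inputChannels, int inputSize,
                             float epsilon = 1e-5f) {
        const int n = inputSize * inputSize;
        for (int c = 0; c < inputChannels; ++c) {
            const float *x = input.data() + c * n;
            float       *y = output.data() + c * n;

            float mean = 0.0f;
            for (int j = 0; j < n; ++j) mean += x[j];
            mean /= n;

            float var = 0.0f;
            for (int j = 0; j < n; ++j) var += (x[j] - mean) * (x[j] - mean);
            var /= n;  // biased estimator, matching the division by d_length

            const float inv_std = 1.0f / std::sqrt(var + epsilon);
            for (int j = 0; j < n; ++j) y[j] = (x[j] - mean) * inv_std;
        }
    }

Running both paths on random input and comparing element-wise (within roughly 1e-4, given the single-precision reduction order differs) is a reasonable unit test for this commit.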