Compute mean and variance

This commit is contained in:
2024-04-29 20:55:11 +02:00
parent 0ab623fa23
commit 5c8d3f7e25
6 changed files with 143 additions and 59 deletions

View File

@@ -135,6 +135,20 @@ __global__ void vec_exp(
const unsigned int len const unsigned int len
); );
/**
 * @brief Compute the square root of each element of the vector
 *
 * Elementwise kernel using a grid-stride loop, so any 1D grid/block
 * launch configuration is valid. Safe to call in place (src == dst):
 * each thread reads and writes the same index only.
 *
 * @param src Device pointer to source vector
 * @param dst Device pointer to destination vector
 * @param len Length of the vector
 */
__global__ void vec_sqrt(
    const float* __restrict__ src,
    float* __restrict__ dst,
    const unsigned int len
);
/** /**
* @brief Max reduction kernel * @brief Max reduction kernel
* *

View File

@@ -74,8 +74,12 @@ class BatchNorm : public WeightedLayer {
float* d_output; float* d_output;
float* d_mean; float* d_mean;
float* d_mean_sub;
float* d_sqrt_var; float* d_sqrt_var;
float* d_length;
float* d_epsilon;
float* d_weights; float* d_weights;
float* d_biases; float* d_biases;

View File

@@ -10,7 +10,7 @@ namespace CUDANet::Utils {
* @param d_vec Pointer to the vector on device * @param d_vec Pointer to the vector on device
* @param length Length of the vector * @param length Length of the vector
*/ */
void print_vec(float *d_vec, const unsigned int length); void print_vec(const float *d_vec, const unsigned int length);
/** /**
* @brief Utility function that clears a vector * @brief Utility function that clears a vector
@@ -27,7 +27,7 @@ void clear(float *d_vector, const unsigned int len);
* @param d_vec Pointer to the vector * @param d_vec Pointer to the vector
* @param length Length of the vector * @param length Length of the vector
*/ */
void sum(float *d_vec, float *d_sum, const unsigned int length); void sum(const float *d_vec, float *d_sum, const unsigned int length);
/** /**
@@ -36,25 +36,16 @@ void sum(float *d_vec, float *d_sum, const unsigned int length);
* @param d_vec Pointer to the vector * @param d_vec Pointer to the vector
* @param length Length of the vector * @param length Length of the vector
*/ */
void max(float *d_vec, float *d_max, const unsigned int length); void max(const float *d_vec, float *d_max, const unsigned int length);
/** // /**
* @brief Compute the mean of a vector // * @brief Compute the variance of a vector
* // *
* @param d_vec // * @param d_vec
* @param d_mean // * @param d_var
* @param length // * @param length
*/ // */
void mean(float *d_vec, float *d_mean, const unsigned int length); // void var(float *d_vec, float *d_var, const unsigned int length);
/**
* @brief Compute the variance of a vector
*
* @param d_vec
* @param d_var
* @param length
*/
void var(float *d_vec, float *d_var, const unsigned int length);
} // namespace CUDANet::Utils } // namespace CUDANet::Utils

View File

@@ -127,6 +127,19 @@ __global__ void Kernels::vec_exp(
} }
} }
/**
 * @brief Compute the square root of each element of the vector
 *
 * Grid-stride loop: correct for any 1D launch configuration, including
 * grids smaller than len. Safe in place (src == dst) since each thread
 * touches only its own index.
 *
 * @param src Device pointer to source vector
 * @param dst Device pointer to destination vector
 * @param len Length of the vector
 */
__global__ void Kernels::vec_sqrt(
    const float* __restrict__ src,
    float* __restrict__ dst,
    const unsigned int len
) {
    // Use unsigned indices: len is unsigned, and a signed `int i` would
    // be a signed/unsigned comparison and could overflow (UB) for large len.
    const unsigned int stride = gridDim.x * blockDim.x;
    for (unsigned int i = blockDim.x * blockIdx.x + threadIdx.x; i < len;
         i += stride) {
        dst[i] = sqrtf(src[i]);  // float overload; no double promotion
    }
}
__global__ void Kernels::max_reduce( __global__ void Kernels::max_reduce(
const float* __restrict__ d_vector, const float* __restrict__ d_vector,

View File

@@ -5,6 +5,7 @@
#include "cuda_helper.cuh" #include "cuda_helper.cuh"
#include "layer.cuh" #include "layer.cuh"
#include "matmul.cuh" #include "matmul.cuh"
#include "vector.cuh"
using namespace CUDANet::Layers; using namespace CUDANet::Layers;
@@ -24,10 +25,15 @@ BatchNorm::BatchNorm(
)); ));
d_mean = nullptr; d_mean = nullptr;
CUDA_CHECK(cudaMalloc((void **)&d_mean, sizeof(float) * inputChannels)); CUDA_CHECK(cudaMalloc((void **)&d_mean, sizeof(float) * inputSize * inputSize));
d_mean_sub = nullptr;
CUDA_CHECK(
cudaMalloc((void **)&d_mean_sub, sizeof(float) * inputSize * inputSize)
);
d_sqrt_var = nullptr; d_sqrt_var = nullptr;
CUDA_CHECK(cudaMalloc((void **)&d_sqrt_var, sizeof(float) * inputChannels)); CUDA_CHECK(cudaMalloc((void **)&d_sqrt_var, sizeof(float) * inputSize * inputSize));
d_weights = nullptr; d_weights = nullptr;
CUDA_CHECK(cudaMalloc((void **)&d_weights, sizeof(float) * inputChannels)); CUDA_CHECK(cudaMalloc((void **)&d_weights, sizeof(float) * inputChannels));
@@ -35,28 +41,36 @@ BatchNorm::BatchNorm(
d_biases = nullptr; d_biases = nullptr;
CUDA_CHECK(cudaMalloc((void **)&d_biases, sizeof(float) * inputChannels)); CUDA_CHECK(cudaMalloc((void **)&d_biases, sizeof(float) * inputChannels));
d_length = nullptr;
CUDA_CHECK(cudaMalloc((void **)&d_length, sizeof(float)));
// BUG FIX: cudaMemset fills *bytes* with the low byte of its int argument;
// it cannot store the float value inputSize*inputSize. Copy a host float
// instead (same pattern as d_epsilon below).
float length = static_cast<float>(inputSize * inputSize);
CUDA_CHECK(cudaMemcpy(d_length, &length, sizeof(float), cudaMemcpyHostToDevice));

d_epsilon = nullptr;
// Small constant added to the variance to avoid division by zero.
float epsilon = 1e-5f;
CUDA_CHECK(cudaMalloc((void **)&d_epsilon, sizeof(float)));
CUDA_CHECK(cudaMemcpy(d_epsilon, &epsilon, sizeof(float), cudaMemcpyHostToDevice));
weights.resize(inputChannels); weights.resize(inputChannels);
biases.resize(inputChannels); biases.resize(inputChannels);
mean.resize(inputChannels);
sqrt_var.resize(inputChannels);
initializeWeights(); initializeWeights();
initializeBiases(); initializeBiases();
initializeMean();
initializeSqrtVar();
toCuda(); toCuda();
gridSize = gridSize =
(inputSize * inputSize * inputChannels + BLOCK_SIZE - 1) / BLOCK_SIZE; (inputSize * inputSize + BLOCK_SIZE - 1) / BLOCK_SIZE;
} }
BatchNorm::~BatchNorm() { BatchNorm::~BatchNorm() {
cudaFree(d_output); cudaFree(d_output);
cudaFree(d_mean); cudaFree(d_mean);
cudaFree(d_mean_sub);
cudaFree(d_sqrt_var); cudaFree(d_sqrt_var);
cudaFree(d_weights); cudaFree(d_weights);
cudaFree(d_biases); cudaFree(d_biases);
cudaFree(d_length);
cudaFree(d_epsilon);
} }
void BatchNorm::initializeWeights() { void BatchNorm::initializeWeights() {
@@ -67,14 +81,6 @@ void BatchNorm::initializeBiases() {
std::fill(biases.begin(), biases.end(), 0.0f); std::fill(biases.begin(), biases.end(), 0.0f);
} }
void BatchNorm::initializeMean() {
std::fill(mean.begin(), mean.end(), 0.0f);
}
void BatchNorm::initializeSqrtVar() {
std::fill(sqrt_var.begin(), sqrt_var.end(), 1.0f);
}
void BatchNorm::setWeights(const float *weights_input) { void BatchNorm::setWeights(const float *weights_input) {
std::copy(weights_input, weights_input + weights.size(), weights.begin()); std::copy(weights_input, weights_input + weights.size(), weights.begin());
toCuda(); toCuda();
@@ -102,14 +108,6 @@ void BatchNorm::toCuda() {
d_biases, biases.data(), sizeof(float) * inputChannels, d_biases, biases.data(), sizeof(float) * inputChannels,
cudaMemcpyHostToDevice cudaMemcpyHostToDevice
)); ));
CUDA_CHECK(cudaMemcpy(
d_mean, mean.data(), sizeof(float) * inputChannels,
cudaMemcpyHostToDevice
));
CUDA_CHECK(cudaMemcpy(
d_sqrt_var, sqrt_var.data(), sizeof(float) * inputChannels,
cudaMemcpyHostToDevice
));
} }
int BatchNorm::getInputSize() { int BatchNorm::getInputSize() {
@@ -122,19 +120,83 @@ int BatchNorm::getOutputSize() {
float *BatchNorm::forward(const float *d_input) { float *BatchNorm::forward(const float *d_input) {
// Compute per-channel batch normalization
for (int i = 0; i < inputChannels; i++) { for (int i = 0; i < inputChannels; i++) {
Kernels::vec_scalar_sub<<<gridSize, BLOCK_SIZE>>>(
// Compute mean
// Sum over all values
Utils::sum(
d_input + i * inputSize * inputSize, d_input + i * inputSize * inputSize,
d_output + i * inputSize * inputSize, d_mean,
&d_mean[i], inputSize * inputSize
);
// Divide sum by length -> mean
Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
d_mean,
d_mean,
d_length,
inputSize * inputSize inputSize * inputSize
); );
CUDA_CHECK(cudaGetLastError()); CUDA_CHECK(cudaGetLastError());
// Subtract mean from input
Kernels::vec_scalar_sub<<<gridSize, BLOCK_SIZE>>>(
d_input + i * inputSize * inputSize,
d_mean_sub,
&d_mean[0],
inputSize * inputSize
);
CUDA_CHECK(cudaGetLastError());
// Compute variance
// Square differences of input - mean
Kernels::vec_vec_mul<<<gridSize, BLOCK_SIZE>>>(
d_mean_sub,
d_mean_sub,
d_sqrt_var,
inputSize * inputSize
);
CUDA_CHECK(cudaGetLastError());
// Sum over all differences
Utils::sum(
d_sqrt_var,
d_sqrt_var,
inputSize * inputSize
);
// Divide by difference sum / length -> variance
Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>( Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
d_sqrt_var,
d_sqrt_var,
d_length,
inputSize * inputSize
);
CUDA_CHECK(cudaGetLastError());
// Add epsilon to variance to avoid division by zero
Kernels::vec_scalar_add<<<gridSize, BLOCK_SIZE>>>(
d_sqrt_var,
d_sqrt_var,
&d_epsilon[0],
inputSize * inputSize
);
CUDA_CHECK(cudaGetLastError());
// Compute squared root of variance
Kernels::vec_sqrt<<<gridSize, BLOCK_SIZE>>>(
d_sqrt_var,
d_sqrt_var,
inputSize * inputSize
);
CUDA_CHECK(cudaGetLastError());
// Divide by squared root of variance
Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
d_mean_sub,
d_output + i * inputSize * inputSize, d_output + i * inputSize * inputSize,
d_output + i * inputSize * inputSize, &d_sqrt_var[0],
&d_sqrt_var[i],
inputSize * inputSize inputSize * inputSize
); );
CUDA_CHECK(cudaGetLastError()); CUDA_CHECK(cudaGetLastError());

View File

@@ -7,7 +7,7 @@
using namespace CUDANet; using namespace CUDANet;
void Utils::print_vec(float* d_vec, const unsigned int length) { void Utils::print_vec(const float* d_vec, const unsigned int length) {
std::vector<float> h_vec(length); std::vector<float> h_vec(length);
CUDA_CHECK(cudaMemcpy( CUDA_CHECK(cudaMemcpy(
h_vec.data(), d_vec, sizeof(float) * length, cudaMemcpyDeviceToHost h_vec.data(), d_vec, sizeof(float) * length, cudaMemcpyDeviceToHost
@@ -24,7 +24,7 @@ void Utils::clear(float* d_vec, const unsigned int length) {
CUDA_CHECK(cudaMemset(d_vec, 0, sizeof(float) * length)); CUDA_CHECK(cudaMemset(d_vec, 0, sizeof(float) * length));
} }
void Utils::max(float* d_vec, float* d_max, const unsigned int length) { void Utils::max(const float* d_vec, float* d_max, const unsigned int length) {
const int grid_size = (length + BLOCK_SIZE - 1) / BLOCK_SIZE; const int grid_size = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
Kernels::max_reduce<<<grid_size, BLOCK_SIZE>>>(d_vec, d_max, length); Kernels::max_reduce<<<grid_size, BLOCK_SIZE>>>(d_vec, d_max, length);
@@ -42,7 +42,7 @@ void Utils::max(float* d_vec, float* d_max, const unsigned int length) {
} }
void Utils::sum(float* d_vec, float* d_sum, const unsigned int length) { void Utils::sum(const float* d_vec, float* d_sum, const unsigned int length) {
const int gridSize = (length + BLOCK_SIZE - 1) / BLOCK_SIZE; const int gridSize = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
@@ -61,14 +61,14 @@ void Utils::sum(float* d_vec, float* d_sum, const unsigned int length) {
} }
} }
void Utils::mean(float* d_vec, float* d_mean, const unsigned int length) { // __device__ float Utils::mean(float* d_vec, const unsigned int length) {
float sum; // float sum = 0;
Utils::sum(d_vec, &sum, length); // for (int i = 0; i < length; ++i) {
*d_mean = sum / length; // sum += d_vec[i];
} // }
void Utils::var(float* d_vec, float* d_mean, float* d_var, const unsigned int length) { // void Utils::var(float* d_vec, float* d_mean, float* d_var, const unsigned int length) {
// TODO: // // TODO:
} // }