Compute batch-norm mean and variance on the device; add a vec_sqrt kernel.

Review notes on this revision of the patch:
  * BatchNorm::BatchNorm fills d_length with cudaMemcpy from a host float.
    cudaMemset stores a byte value, so it cannot write a float element count.
  * The vec_sqrt declaration no longer documents "@return __global__"; the
    kernel returns void.
  * Hunk headers were recounted for the lines added or removed above.
  * Line breaks, indentation, blank context lines and kernel launch
    configurations were restored from a whitespace-collapsed copy. Apply with
    `git apply --ignore-whitespace` and re-check the context if a hunk is
    rejected.

diff --git a/include/kernels/matmul.cuh b/include/kernels/matmul.cuh
index 48512e1..27a4b24 100644
--- a/include/kernels/matmul.cuh
+++ b/include/kernels/matmul.cuh
@@ -135,6 +135,19 @@ __global__ void vec_exp(
     const unsigned int len
 );
 
+/**
+ * @brief Compute the square root of each element of the vector
+ *
+ * @param src Device pointer to source vector
+ * @param dst Device pointer to destination vector
+ * @param len Length of the vector
+ */
+__global__ void vec_sqrt(
+    const float* __restrict__ src,
+    float* __restrict__ dst,
+    const unsigned int len
+);
+
 /**
  * @brief Max reduction kernel
  *
diff --git a/include/layers/batch_norm.cuh b/include/layers/batch_norm.cuh
index 142c467..cd4b623 100644
--- a/include/layers/batch_norm.cuh
+++ b/include/layers/batch_norm.cuh
@@ -74,8 +74,12 @@ class BatchNorm : public WeightedLayer {
     float* d_output;
 
     float* d_mean;
+    float* d_mean_sub;  // scratch: one channel of input with its mean subtracted
     float* d_sqrt_var;
 
+    float* d_length;   // device scalar: inputSize * inputSize, used as divisor
+    float* d_epsilon;  // device scalar: 1e-5f, added to the variance
+
     float* d_weights;
     float* d_biases;
 
diff --git a/include/utils/vector.cuh b/include/utils/vector.cuh
index 0527e21..6817bcb 100644
--- a/include/utils/vector.cuh
+++ b/include/utils/vector.cuh
@@ -10,7 +10,7 @@ namespace CUDANet::Utils {
  * @param d_vec Pointer to the vector on device
  * @param length Length of the vector
  */
-void print_vec(float *d_vec, const unsigned int length);
+void print_vec(const float *d_vec, const unsigned int length);
 
 /**
  * @brief Utility function that clears a vector
@@ -27,7 +27,7 @@ void clear(float *d_vector, const unsigned int len);
  * @param d_vec Pointer to the vector
  * @param length Length of the vector
  */
-void sum(float *d_vec, float *d_sum, const unsigned int length);
+void sum(const float *d_vec, float *d_sum, const unsigned int length);
 
 
 /**
@@ -36,25 +36,16 @@ void sum(float *d_vec, float *d_sum, const unsigned int length);
  * @param d_vec Pointer to the vector
  * @param length Length of the vector
  */
-void max(float *d_vec, float *d_max, const unsigned int length);
+void max(const float *d_vec, float *d_max, const unsigned int length);
 
-/**
- * @brief Compute the mean of a vector
- *
- * @param d_vec
- * @param d_mean
- * @param length
- */
-void mean(float *d_vec, float *d_mean, const unsigned int length);
-
-/**
- * @brief Compute the variance of a vector
- *
- * @param d_vec
- * @param d_var
- * @param length
- */
-void var(float *d_vec, float *d_var, const unsigned int length);
+// /**
+//  * @brief Compute the variance of a vector
+//  *
+//  * @param d_vec
+//  * @param d_var
+//  * @param length
+//  */
+// void var(float *d_vec, float *d_var, const unsigned int length);
 
 } // namespace CUDANet::Utils
 
diff --git a/src/kernels/matmul.cu b/src/kernels/matmul.cu
index 8a8e9e0..e4ef8da 100644
--- a/src/kernels/matmul.cu
+++ b/src/kernels/matmul.cu
@@ -127,6 +127,19 @@ __global__ void Kernels::vec_exp(
     }
 }
 
+__global__ void Kernels::vec_sqrt(
+    const float* __restrict__ src,
+    float* __restrict__ dst,
+    const unsigned int len
+) {
+    int stride = gridDim.x * blockDim.x;
+    int tid = blockDim.x * blockIdx.x + threadIdx.x;
+
+    for (int i = tid; i < len; i += stride) {
+        dst[i] = sqrtf(src[i]);
+    }
+}
+
 
 __global__ void Kernels::max_reduce(
     const float* __restrict__ d_vector,
diff --git a/src/layers/batch_norm.cu b/src/layers/batch_norm.cu
index 1df5967..27f8876 100644
--- a/src/layers/batch_norm.cu
+++ b/src/layers/batch_norm.cu
@@ -5,6 +5,7 @@
 #include "cuda_helper.cuh"
 #include "layer.cuh"
 #include "matmul.cuh"
+#include "vector.cuh"
 
 using namespace CUDANet::Layers;
 
@@ -24,10 +25,15 @@ BatchNorm::BatchNorm(
     ));
 
     d_mean = nullptr;
-    CUDA_CHECK(cudaMalloc((void **)&d_mean, sizeof(float) * inputChannels));
+    CUDA_CHECK(cudaMalloc((void **)&d_mean, sizeof(float) * inputSize * inputSize));
+
+    d_mean_sub = nullptr;
+    CUDA_CHECK(
+        cudaMalloc((void **)&d_mean_sub, sizeof(float) * inputSize * inputSize)
+    );
 
     d_sqrt_var = nullptr;
-    CUDA_CHECK(cudaMalloc((void **)&d_sqrt_var, sizeof(float) * inputChannels));
+    CUDA_CHECK(cudaMalloc((void **)&d_sqrt_var, sizeof(float) * inputSize * inputSize));
 
     d_weights = nullptr;
     CUDA_CHECK(cudaMalloc((void **)&d_weights, sizeof(float) * inputChannels));
@@ -35,28 +41,38 @@ BatchNorm::BatchNorm(
     d_biases = nullptr;
     CUDA_CHECK(cudaMalloc((void **)&d_biases, sizeof(float) * inputChannels));
 
+    d_length = nullptr;
+    // cudaMemset stores a byte value, so upload the float divisor instead
+    float length = static_cast<float>(inputSize * inputSize);
+    CUDA_CHECK(cudaMalloc((void **)&d_length, sizeof(float)));
+    CUDA_CHECK(cudaMemcpy(d_length, &length, sizeof(float), cudaMemcpyHostToDevice));
+
+    d_epsilon = nullptr;
+    float epsilon = 1e-5f;
+    CUDA_CHECK(cudaMalloc((void **)&d_epsilon, sizeof(float)));
+    CUDA_CHECK(cudaMemcpy(d_epsilon, &epsilon, sizeof(float), cudaMemcpyHostToDevice));
+
     weights.resize(inputChannels);
     biases.resize(inputChannels);
-    mean.resize(inputChannels);
-    sqrt_var.resize(inputChannels);
 
     initializeWeights();
     initializeBiases();
-    initializeMean();
-    initializeSqrtVar();
 
     toCuda();
 
     gridSize =
-        (inputSize * inputSize * inputChannels + BLOCK_SIZE - 1) / BLOCK_SIZE;
+        (inputSize * inputSize + BLOCK_SIZE - 1) / BLOCK_SIZE;
 }
 
 BatchNorm::~BatchNorm() {
     cudaFree(d_output);
     cudaFree(d_mean);
+    cudaFree(d_mean_sub);
     cudaFree(d_sqrt_var);
     cudaFree(d_weights);
     cudaFree(d_biases);
+    cudaFree(d_length);
+    cudaFree(d_epsilon);
 }
 
 void BatchNorm::initializeWeights() {
@@ -67,14 +83,6 @@ void BatchNorm::initializeBiases() {
     std::fill(biases.begin(), biases.end(), 0.0f);
 }
 
-void BatchNorm::initializeMean() {
-    std::fill(mean.begin(), mean.end(), 0.0f);
-}
-
-void BatchNorm::initializeSqrtVar() {
-    std::fill(sqrt_var.begin(), sqrt_var.end(), 1.0f);
-}
-
 void BatchNorm::setWeights(const float *weights_input) {
     std::copy(weights_input, weights_input + weights.size(), weights.begin());
     toCuda();
@@ -102,14 +110,6 @@ void BatchNorm::toCuda() {
         d_biases, biases.data(), sizeof(float) * inputChannels,
         cudaMemcpyHostToDevice
     ));
-    CUDA_CHECK(cudaMemcpy(
-        d_mean, mean.data(), sizeof(float) * inputChannels,
-        cudaMemcpyHostToDevice
-    ));
-    CUDA_CHECK(cudaMemcpy(
-        d_sqrt_var, sqrt_var.data(), sizeof(float) * inputChannels,
-        cudaMemcpyHostToDevice
-    ));
 }
 
 int BatchNorm::getInputSize() {
@@ -122,19 +122,83 @@ int BatchNorm::getOutputSize() {
 
 float *BatchNorm::forward(const float *d_input) {
 
+    // Compute per-channel batch normalization
     for (int i = 0; i < inputChannels; i++) {
-        Kernels::vec_scalar_sub<<<gridSize, BLOCK_SIZE>>>(
+
+        // Compute mean
+        // Sum over all values
+        Utils::sum(
             d_input + i * inputSize * inputSize,
-            d_output + i * inputSize * inputSize,
-            &d_mean[i],
+            d_mean,
+            inputSize * inputSize
+        );
+
+        // Divide sum by length -> mean
+        Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
+            d_mean,
+            d_mean,
+            d_length,
             inputSize * inputSize
         );
         CUDA_CHECK(cudaGetLastError());
 
+        // Subtract mean from input
+        Kernels::vec_scalar_sub<<<gridSize, BLOCK_SIZE>>>(
+            d_input + i * inputSize * inputSize,
+            d_mean_sub,
+            &d_mean[0],
+            inputSize * inputSize
+        );
+        CUDA_CHECK(cudaGetLastError());
+
+        // Compute variance
+        // Square differences of input - mean
+        Kernels::vec_vec_mul<<<gridSize, BLOCK_SIZE>>>(
+            d_mean_sub,
+            d_mean_sub,
+            d_sqrt_var,
+            inputSize * inputSize
+        );
+        CUDA_CHECK(cudaGetLastError());
+
+        // Sum over all squared differences
+        Utils::sum(
+            d_sqrt_var,
+            d_sqrt_var,
+            inputSize * inputSize
+        );
+
+        // Divide sum of squared differences by length -> variance
         Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
+            d_sqrt_var,
+            d_sqrt_var,
+            d_length,
+            inputSize * inputSize
+        );
+        CUDA_CHECK(cudaGetLastError());
+
+        // Add epsilon to variance to avoid division by zero
+        Kernels::vec_scalar_add<<<gridSize, BLOCK_SIZE>>>(
+            d_sqrt_var,
+            d_sqrt_var,
+            &d_epsilon[0],
+            inputSize * inputSize
+        );
+        CUDA_CHECK(cudaGetLastError());
+
+        // Compute square root of variance
+        Kernels::vec_sqrt<<<gridSize, BLOCK_SIZE>>>(
+            d_sqrt_var,
+            d_sqrt_var,
+            inputSize * inputSize
+        );
+        CUDA_CHECK(cudaGetLastError());
+
+        // Divide by square root of variance
+        Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
+            d_mean_sub,
             d_output + i * inputSize * inputSize,
-            d_output + i * inputSize * inputSize,
-            &d_sqrt_var[i],
+            &d_sqrt_var[0],
             inputSize * inputSize
         );
         CUDA_CHECK(cudaGetLastError());
diff --git a/src/utils/vector.cu b/src/utils/vector.cu
index 3d69273..e0ab2b5 100644
--- a/src/utils/vector.cu
+++ b/src/utils/vector.cu
@@ -7,7 +7,7 @@
 
 using namespace CUDANet;
 
-void Utils::print_vec(float* d_vec, const unsigned int length) {
+void Utils::print_vec(const float* d_vec, const unsigned int length) {
     std::vector<float> h_vec(length);
     CUDA_CHECK(cudaMemcpy(
         h_vec.data(), d_vec, sizeof(float) * length, cudaMemcpyDeviceToHost
@@ -24,7 +24,7 @@ void Utils::clear(float* d_vec, const unsigned int length) {
     CUDA_CHECK(cudaMemset(d_vec, 0, sizeof(float) * length));
 }
 
-void Utils::max(float* d_vec, float* d_max, const unsigned int length) {
+void Utils::max(const float* d_vec, float* d_max, const unsigned int length) {
 
     const int grid_size = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
     Kernels::max_reduce<<<grid_size, BLOCK_SIZE>>>(d_vec, d_max, length);
@@ -42,7 +42,7 @@ void Utils::max(float* d_vec, float* d_max, const unsigned int length) {
 }
 
 
-void Utils::sum(float* d_vec, float* d_sum, const unsigned int length) {
+void Utils::sum(const float* d_vec, float* d_sum, const unsigned int length) {
 
     const int gridSize = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
 
@@ -61,14 +61,14 @@ void Utils::sum(float* d_vec, float* d_sum, const unsigned int length) {
     }
 }
 
-void Utils::mean(float* d_vec, float* d_mean, const unsigned int length) {
-    float sum;
-    Utils::sum(d_vec, &sum, length);
-    *d_mean = sum / length;
-}
+// __device__ float Utils::mean(float* d_vec, const unsigned int length) {
+//     float sum = 0;
+//     for (int i = 0; i < length; ++i) {
+//         sum += d_vec[i];
+//     }
 
-void Utils::var(float* d_vec, float* d_mean, float* d_var, const unsigned int length) {
+// void Utils::var(float* d_vec, float* d_mean, float* d_var, const unsigned int length) {
 
-    // TODO:
+//     // TODO:
 
-}
\ No newline at end of file
+// }
\ No newline at end of file