diff --git a/include/utils/vector.cuh b/include/utils/vector.cuh
index ca343a3..24e36af 100644
--- a/include/utils/vector.cuh
+++ b/include/utils/vector.cuh
@@ -49,14 +49,18 @@ void max(const float *d_vec, float *d_max, const unsigned int length);
  */
 void mean(const float *d_vec, float *d_mean, float *d_length, int length);
 
-// /**
-// * @brief Compute the variance of a vector
-// *
-// * @param d_vec
-// * @param d_var
-// * @param length
-// */
-// void var(float *d_vec, float *d_var, const unsigned int length);
+/**
+ * @brief Compute the variance of a vector
+ *
+ * The input is squared as-is (no mean is subtracted), so d_vec must
+ * already hold mean-centered values.
+ *
+ * @param d_vec Device input vector
+ * @param d_var Device output; also receives the intermediate squares
+ * @param d_length Device divisor handed to vec_scalar_div (presumably the element count as a float -- confirm)
+ * @param length Number of elements in d_vec
+ */
+void var(float *d_vec, float *d_var, float *d_length, const unsigned int length);
 
 } // namespace CUDANet::Utils
 
diff --git a/src/layers/batch_norm.cu b/src/layers/batch_norm.cu
index b9f9fa9..70e532a 100644
--- a/src/layers/batch_norm.cu
+++ b/src/layers/batch_norm.cu
@@ -142,30 +142,12 @@ float *BatchNorm::forward(const float *d_input) {
         CUDA_CHECK(cudaGetLastError());
 
         // Compute variance
-        // Square differences of input - mean
-        Kernels::vec_vec_mul<<<gridSize, BLOCK_SIZE>>>(
+        Utils::var(
             d_mean_sub,
-            d_mean_sub,
-            d_sqrt_var,
-            inputSize * inputSize
-        );
-        CUDA_CHECK(cudaGetLastError());
-
-        // Sum over all differences
-        Utils::sum(
-            d_sqrt_var,
-            d_sqrt_var,
-            inputSize * inputSize
-        );
-
-        // Divide by difference sum / length -> variance
-        Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
-            d_sqrt_var,
             d_sqrt_var,
             d_length,
             inputSize * inputSize
         );
-        CUDA_CHECK(cudaGetLastError());
 
         // Add epsilon to variance to avoid division by zero
         Kernels::vec_scalar_add<<<gridSize, BLOCK_SIZE>>>(
@@ -193,6 +175,7 @@ float *BatchNorm::forward(const float *d_input) {
         );
         CUDA_CHECK(cudaGetLastError());
 
+        // Multiply by weights
         Kernels::vec_scalar_mul<<<gridSize, BLOCK_SIZE>>>(
             d_output + i * inputSize * inputSize,
             d_output + i * inputSize * inputSize,
@@ -201,6 +184,7 @@ float *BatchNorm::forward(const float *d_input) {
         );
         CUDA_CHECK(cudaGetLastError());
 
+        // Add biases
         Kernels::vec_scalar_add<<<gridSize, BLOCK_SIZE>>>(
             d_output + i * inputSize * inputSize,
             d_output + i * inputSize * inputSize,
diff --git a/src/utils/vector.cu b/src/utils/vector.cu
index d7dcf72..9dfb950 100644
--- a/src/utils/vector.cu
+++ b/src/utils/vector.cu
@@ -73,4 +73,45 @@ void Utils::mean(const float* d_vec, float* d_mean, float *d_length, int length)
     );
     CUDA_CHECK(cudaGetLastError());
 
+}
+
+/**
+ * Mean of the element-wise squares of d_vec.
+ *
+ * No mean is subtracted here: d_vec is squared as-is, so the result is the
+ * variance only if the caller passes already mean-centered values (as
+ * BatchNorm::forward does with d_mean_sub).
+ *
+ * d_var receives the squares first and the final result afterwards
+ * (presumably it needs room for `length` floats -- confirm).
+ */
+void Utils::var(float* d_vec, float* d_var, float *d_length, const unsigned int length) {
+
+    const int gridSize = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
+
+    // Square every element: d_var = d_vec * d_vec
+    Kernels::vec_vec_mul<<<gridSize, BLOCK_SIZE>>>(
+        d_vec,
+        d_vec,
+        d_var,
+        length
+    );
+    CUDA_CHECK(cudaGetLastError());
+
+    // Sum over all squared values
+    Utils::sum(
+        d_var,
+        d_var,
+        length
+    );
+
+    // Divide the sum by the value in d_length -> variance
+    Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
+        d_var,
+        d_var,
+        d_length,
+        length
+    );
+    CUDA_CHECK(cudaGetLastError());
+
 }
\ No newline at end of file