From 4b9d123e945dfcceaad7d64e5e978c62e75a9140 Mon Sep 17 00:00:00 2001
From: LordMathis
Date: Thu, 11 Apr 2024 22:22:33 +0200
Subject: [PATCH] Implement device vector utils

---
 include/kernels/matmul.cuh  |  6 ----
 include/utils/vector.cuh    | 44 ++++++++++++++++++++++++++++
 src/kernels/matmul.cu       | 12 --------
 src/layers/dense.cu         |  4 +++
 src/utils/vector.cu         | 57 ++++++++++++++++++++++++++++++++++++
 test/kernels/test_matmul.cu |  4 ++-
 6 files changed, 108 insertions(+), 19 deletions(-)
 create mode 100644 include/utils/vector.cuh
 create mode 100644 src/utils/vector.cu

diff --git a/include/kernels/matmul.cuh b/include/kernels/matmul.cuh
index 477e5b4..3784251 100644
--- a/include/kernels/matmul.cuh
+++ b/include/kernels/matmul.cuh
@@ -103,12 +103,6 @@ __global__ void sum_reduce(
     const unsigned int len
 );
 
-
-__global__ void clear(
-    float* __restrict__ d_vector,
-    const unsigned int len
-);
-
 }  // namespace CUDANet::Kernels
 
 #endif  // CUDANET_MATMUL_H
\ No newline at end of file
diff --git a/include/utils/vector.cuh b/include/utils/vector.cuh
new file mode 100644
index 0000000..0ec913a
--- /dev/null
+++ b/include/utils/vector.cuh
@@ -0,0 +1,44 @@
+#ifndef CUDANET_VECTOR_H
+#define CUDANET_VECTOR_H
+
+namespace CUDANet::Utils {
+
+
+/**
+ * @brief Utility function that prints a vector
+ *
+ * @param d_vec Pointer to the vector on device
+ * @param length Length of the vector
+ */
+void print_vec(float *d_vec, const unsigned int length);
+
+/**
+ * @brief Utility function that clears a vector
+ *
+ * @param d_vector Pointer to the vector on device
+ * @param len Length of the vector
+ */
+void clear(float *d_vector, const unsigned int len);
+
+
+/**
+ * @brief Utility function that computes the sum of a vector and writes it to d_sum
+ *
+ * @param d_vec Pointer to the vector
+ * @param length Length of the vector
+ */
+void sum(float *d_vec, float *d_sum, const unsigned int length);
+
+
+/**
+ * @brief Utility function that computes the max of a vector and writes it to d_max
+ *
+ * @param d_vec Pointer to the vector
+ * @param length Length of the vector
+ */
+void max(float *d_vec, float *d_max, const unsigned int length);
+
+
+}  // namespace CUDANet::Utils
+
+#endif  // CUDANET_VECTOR_H
\ No newline at end of file
diff --git a/src/kernels/matmul.cu b/src/kernels/matmul.cu
index b43d2a9..d956cf7 100644
--- a/src/kernels/matmul.cu
+++ b/src/kernels/matmul.cu
@@ -95,18 +95,6 @@ __global__ void Kernels::vec_exp(
 }
 
 
-__global__ void Kernels::clear(
-    float* __restrict__ d_vector,
-    const unsigned int w
-) {
-    int tid = blockDim.x * blockIdx.x + threadIdx.x;
-    if (tid >= w) {
-        return;
-    }
-    d_vector[tid] = 0.0f;
-}
-
-
 __global__ void Kernels::max_reduce(
     const float* __restrict__ d_vector,
     float* __restrict__ d_output,
diff --git a/src/layers/dense.cu b/src/layers/dense.cu
index 751b5a9..0ab538d 100644
--- a/src/layers/dense.cu
+++ b/src/layers/dense.cu
@@ -5,6 +5,7 @@
 #include 
 #include 
 
+#include "vector.cuh"
 #include "activation.cuh"
 #include "cuda_helper.cuh"
 #include "dense.cuh"
@@ -63,6 +64,9 @@ void Dense::initializeBiases() {
 }
 
 float* Dense::forward(const float* d_input) {
+
+    CUDANet::Utils::clear(d_output, outputSize);
+
     Kernels::mat_vec_mul<<>>(
         d_weights, d_input, d_output, inputSize, outputSize
     );
diff --git a/src/utils/vector.cu b/src/utils/vector.cu
new file mode 100644
index 0000000..6fb7e23
--- /dev/null
+++ b/src/utils/vector.cu
@@ -0,0 +1,57 @@
+#include <iostream>
+#include <vector>
+
+#include "vector.cuh"
+#include "matmul.cuh"
+#include "cuda_helper.cuh"
+
+using namespace CUDANet;
+
+void Utils::print_vec(float* d_vec, const unsigned int length) {
+    std::vector<float> h_vec(length);
+    CUDA_CHECK(cudaMemcpy(
+        h_vec.data(), d_vec, sizeof(float) * length, cudaMemcpyDeviceToHost
+    ));
+
+    for (unsigned int i = 0; i < length; ++i) {
+        std::cout << h_vec[i] << ", ";
+    }
+
+    std::cout << std::endl;
+}
+
+void Utils::clear(float* d_vec, const unsigned int length) {
+    CUDA_CHECK(cudaMemset(d_vec, 0, sizeof(float) * length));
+}
+
+void Utils::max(float* d_vec, float* d_max, const unsigned int length) {
+
+    const int grid_size = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
+
+    Kernels::max_reduce<<<grid_size, BLOCK_SIZE>>>(d_vec, d_max, length);
+
+    int remaining = grid_size;
+    while (remaining > 1) {
+        int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
+        CUDANet::Kernels::max_reduce<<<blocks_needed, BLOCK_SIZE>>>(d_max, d_max, remaining);
+        remaining = blocks_needed;
+    }
+
+}
+
+void Utils::sum(float* d_vec, float* d_sum, const unsigned int length) {
+
+    const int gridSize = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
+
+
+    CUDANet::Kernels::sum_reduce<<<gridSize, BLOCK_SIZE>>>(
+        d_vec, d_sum, length
+    );
+
+    int remaining = gridSize;
+    while (remaining > 1) {
+        int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
+        CUDANet::Kernels::sum_reduce<<<blocks_needed, BLOCK_SIZE>>>(d_sum, d_sum, remaining);
+        remaining = blocks_needed;
+    }
+}
\ No newline at end of file
diff --git a/test/kernels/test_matmul.cu b/test/kernels/test_matmul.cu
index bb42927..d91fb9a 100644
--- a/test/kernels/test_matmul.cu
+++ b/test/kernels/test_matmul.cu
@@ -4,6 +4,7 @@
 #include 
 
 #include "cuda_helper.cuh"
+#include "vector.cuh"
 #include "matmul.cuh"
 
 TEST(MatMulTest, MatVecMulTest) {
@@ -45,7 +46,7 @@
     int THREADS_PER_BLOCK = std::max(w, h);
     int BLOCKS = 1;
 
-    CUDANet::Kernels::clear<<<BLOCKS, THREADS_PER_BLOCK>>>(d_output, h);
+    CUDANet::Utils::clear(d_output, h);
     CUDANet::Kernels::mat_vec_mul<<<BLOCKS, THREADS_PER_BLOCK>>>(d_matrix, d_vector, d_output, w, h);
 
     cudaStatus = cudaDeviceSynchronize();
@@ -198,6 +199,7 @@ TEST(MatMulTest, SumReduceTest) {
         remaining = blocks_needed;
     }
 
+    std::vector<float> sum(n);
     cudaStatus = cudaMemcpy(
         sum.data(), d_sum, sizeof(float) * n, cudaMemcpyDeviceToHost