diff --git a/include/kernels/activation_functions.cuh b/include/kernels/activation_functions.cuh
index fee787e..a74c2cc 100644
--- a/include/kernels/activation_functions.cuh
+++ b/include/kernels/activation_functions.cuh
@@ -29,45 +29,6 @@ __global__ void relu(
     const unsigned int len
 );
 
-/**
- * @brief Softmax activation exponentiation kernel
- *
- * @param src Pointer to the source array
- * @param dst Pointer to the destination array
- * @param len Length of the arrays
- */
-__global__ void softmax_exp(
-    const float* __restrict__ src,
-    float* __restrict__ dst,
-    const unsigned int len
-);
-
-/**
- * @brief
- *
- * @param d_vector Device pointer to vector
- * @param d_output Device pointer to output vector
- * @param w Length of the vector
- */
-__global__ void softmax_sum(
-    const float* __restrict__ d_vector,
-    float* __restrict__ d_output
-);
-
-/**
- * @brief Softmax activation function kernel
- *
- * @param src Pointer to the source array
- * @param dst Pointer to the destination array
- * @param len Length of the arrays
- */
-__global__ void softmax_div(
-    const float* __restrict__ src,
-    float* __restrict__ dst,
-    const float* __restrict__ sum,
-    const unsigned int len
-);
-
 } // namespace CUDANet::Kernels
 
 #endif // CUDANET_ACTIVATION_FUNCTIONS_H
\ No newline at end of file
diff --git a/include/kernels/matmul.cuh b/include/kernels/matmul.cuh
index de70f1f..477e5b4 100644
--- a/include/kernels/matmul.cuh
+++ b/include/kernels/matmul.cuh
@@ -35,17 +35,6 @@ __global__ void vec_vec_add(
     const unsigned int w
 );
 
-/**
- * @brief Max reduction kernel
- *
- * @param d_vector Device pointer to vector
- * @param d_output Device pointer to output vector
- */
-__global__ void max_reduce(
-    const float* __restrict__ d_vector,
-    float* __restrict__ d_output
-);
-
 /**
- * @brief Add scalar to each element of the vector
+ * @brief Subtract a scalar from each element of the vector
  *
@@ -56,15 +45,68 @@ __global__ void max_reduce(
- * @param d_vector Device pointer to vector
- * @param d_scalar Device pointer to scalar
- * @param d_output Device pointer to output vector
- * @return __global__
+ * @param d_src Device pointer to the source vector
+ * @param d_out Device pointer to the output vector
+ * @param d_scalar Device pointer to the scalar
+ * @param len Length of the vectors
  */
 __global__ void vec_scalar_sub(
-    const float* __restrict__ d_vector,
+    const float* __restrict__ d_src,
+    float* __restrict__ d_out,
     const float* __restrict__ d_scalar,
-    float* __restrict__ d_output,
-    const unsigned int w
+    const unsigned int len
 );
 
+/**
+ * @brief Divide each element of the vector by a scalar
+ *
+ * @param d_src Device pointer to the source vector
+ * @param d_out Device pointer to the output vector
+ * @param d_scalar Device pointer to the scalar
+ * @param len Length of the vectors
+ */
+__global__ void vec_scalar_div(
+    const float* __restrict__ d_src,
+    float* __restrict__ d_out,
+    const float* __restrict__ d_scalar,
+    const unsigned int len
+);
+
+/**
+ * @brief Exponentiate each element of the vector
+ *
+ * @param src Pointer to the source array
+ * @param dst Pointer to the destination array
+ * @param len Length of the arrays
+ */
+__global__ void vec_exp(
+    const float* __restrict__ src,
+    float* __restrict__ dst,
+    const unsigned int len
+);
+
+/**
+ * @brief Max reduction kernel
+ *
+ * @param d_vector Device pointer to vector
+ * @param d_output Device pointer to output vector
+ * @param len Length of the vector
+ */
+__global__ void max_reduce(
+    const float* __restrict__ d_vector,
+    float* __restrict__ d_output,
+    const unsigned int len
+);
+
+/**
+ * @brief Sum reduction kernel
+ *
+ * @param d_vector Device pointer to vector
+ * @param d_output Device pointer to output vector
+ * @param len Length of the vector
+ */
+__global__ void sum_reduce(
+    const float* __restrict__ d_vector,
+    float* __restrict__ d_output,
+    const unsigned int len
+);
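+
+// Note: max_reduce and sum_reduce are block-level reductions; each launch
+// writes one partial result per block into d_output, so callers re-launch
+// the kernel on the partials until a single value remains (see the
+// reduction loops in the updated tests below).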
+
 __global__ void clear(
     float* __restrict__ d_vector,
-    const unsigned int w
+    const unsigned int len
 );
 
 } // namespace CUDANet::Kernels
diff --git a/src/kernels/activation_functions.cu b/src/kernels/activation_functions.cu
index 5864043..72bca66 100644
--- a/src/kernels/activation_functions.cu
+++ b/src/kernels/activation_functions.cu
@@ -28,51 +28,3 @@ __global__ void Kernels::relu(
         dst[i] = src[i] < 0.0 ? 0.0 : src[i];
     }
 }
-
-__global__ void Kernels::softmax_exp(
-    const float* __restrict__ src,
-    float* __restrict__ dst,
-    const unsigned int len
-) {
-    int stride = gridDim.x * blockDim.x;
-    int tid = blockDim.x * blockIdx.x + threadIdx.x;
-
-    for (int i = tid; i < len; i += stride) {
-        dst[i] = expf(src[i]);
-    }
-}
-
-__global__ void Kernels::softmax_sum(
-    const float* __restrict__ d_vector,
-    float* __restrict__ d_output
-) {
-    __shared__ float partial_sum[BLOCK_SIZE];
-    int i = blockIdx.x * blockDim.x + threadIdx.x;
-    partial_sum[threadIdx.x] = d_vector[i];
-    __syncthreads();
-
-    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
-        if (threadIdx.x < s) {
-            partial_sum[threadIdx.x] += partial_sum[threadIdx.x + s];
-        }
-        __syncthreads();
-    }
-
-    if (threadIdx.x == 0) {
-        d_output[blockIdx.x] = partial_sum[0];
-    }
-}
-
-__global__ void Kernels::softmax_div(
-    const float* __restrict__ src,
-    float* __restrict__ dst,
-    const float* __restrict__ sum,
-    const unsigned int len
-) {
-    int stride = gridDim.x * blockDim.x;
-    int tid = blockDim.x * blockIdx.x + threadIdx.x;
-
-    for (int i = tid; i < len; i += stride) {
-        dst[i] = src[i] / sum[0];
-    }
-}
\ No newline at end of file
diff --git a/src/kernels/matmul.cu b/src/kernels/matmul.cu
index 80e3c00..b43d2a9 100644
--- a/src/kernels/matmul.cu
+++ b/src/kernels/matmul.cu
@@ -3,6 +3,7 @@
 
 using namespace CUDANet;
 
+
 __global__ void Kernels::mat_vec_mul(
     const float* __restrict__ d_matrix,
     const float* __restrict__ d_vector,
@@ -37,6 +38,7 @@ __global__ void Kernels::mat_vec_mul(
     d_output[tid] = temp;
 }
 
+
 __global__ void Kernels::vec_vec_add(
     const float* __restrict__ d_vector1,
     const float* __restrict__ d_vector2,
@@ -50,14 +52,75 @@ __global__ void Kernels::vec_vec_add(
     d_output[tid] = d_vector1[tid] + d_vector2[tid];
 }
 
+
+__global__ void Kernels::vec_scalar_sub(
+    const float* __restrict__ d_src,
+    float* __restrict__ d_out,
+    const float* __restrict__ d_scalar,
+    const unsigned int len
+) {
+    int tid = blockDim.x * blockIdx.x + threadIdx.x;
+    if (tid >= len) {
+        return;
+    }
+    d_out[tid] = d_src[tid] - d_scalar[0];
+}
+
+
+__global__ void Kernels::vec_scalar_div(
+    const float* __restrict__ d_src,
+    float* __restrict__ d_out,
+    const float* __restrict__ d_scalar,
+    const unsigned int len
+) {
+    int tid = blockDim.x * blockIdx.x + threadIdx.x;
+    if (tid >= len) {
+        return;
+    }
+    d_out[tid] = d_src[tid] / d_scalar[0];
+}
+
+
+__global__ void Kernels::vec_exp(
+    const float* __restrict__ src,
+    float* __restrict__ dst,
+    const unsigned int len
+) {
+    int stride = gridDim.x * blockDim.x;
+    int tid = blockDim.x * blockIdx.x + threadIdx.x;
+
+    for (int i = tid; i < len; i += stride) {
+        dst[i] = expf(src[i]);
+    }
+}
+
+
+__global__ void Kernels::clear(
+    float* __restrict__ d_vector,
+    const unsigned int len
+) {
+    int tid = blockDim.x * blockIdx.x + threadIdx.x;
+    if (tid >= len) {
+        return;
+    }
+    d_vector[tid] = 0.0f;
+}
+
+
 __global__ void Kernels::max_reduce(
     const float* __restrict__ d_vector,
-    float* __restrict__ d_output
+    float* __restrict__ d_output,
+    const unsigned int len
 ) {
     __shared__ float shared_max[BLOCK_SIZE];
     int i = blockIdx.x * blockDim.x + threadIdx.x;
 
-    shared_max[threadIdx.x] = d_vector[i];
+    if (i < len) {
+        shared_max[threadIdx.x] = d_vector[i];
+    } else {
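+        // Out-of-range threads contribute -INFINITY so they can never win the max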
+        shared_max[threadIdx.x] = -INFINITY;
+    }
+
     __syncthreads();
 
     for (int s = blockDim.x / 2; s > 0; s >>= 1) {
@@ -72,26 +135,30 @@ __global__ void Kernels::max_reduce(
     }
 }
 
-__global__ void Kernels::vec_scalar_sub(
+__global__ void Kernels::sum_reduce(
     const float* __restrict__ d_vector,
-    const float* __restrict__ d_scalar,
     float* __restrict__ d_output,
-    const unsigned int w
+    const unsigned int len
 ) {
-    int tid = blockDim.x * blockIdx.x + threadIdx.x;
-    if (tid >= w) {
-        return;
-    }
-    d_output[tid] = d_vector[tid] - d_scalar[0];
-}
+    __shared__ float partial_sum[BLOCK_SIZE];
+    int i = blockIdx.x * blockDim.x + threadIdx.x;
 
-__global__ void Kernels::clear(
-    float* __restrict__ d_vector,
-    const unsigned int w
-) {
-    int tid = blockDim.x * blockIdx.x + threadIdx.x;
-    if (tid >= w) {
-        return;
+    if (i < len) {
+        partial_sum[threadIdx.x] = d_vector[i];
+    } else {
+        partial_sum[threadIdx.x] = 0.0f;
+    }
+
+    __syncthreads();
+
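+    // Tree reduction in shared memory: halve the number of active threads
+    // each step until partial_sum[0] holds this block's total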
+    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+        if (threadIdx.x < s) {
+            partial_sum[threadIdx.x] += partial_sum[threadIdx.x + s];
+        }
+        __syncthreads();
+    }
+
+    if (threadIdx.x == 0) {
+        d_output[blockIdx.x] = partial_sum[0];
     }
-    d_vector[tid] = 0.0f;
 }
\ No newline at end of file
diff --git a/test/kernels/test_activation_functions.cu b/test/kernels/test_activation_functions.cu
index 72b0e73..714db19 100644
--- a/test/kernels/test_activation_functions.cu
+++ b/test/kernels/test_activation_functions.cu
@@ -4,6 +4,7 @@
 #include <gtest/gtest.h>
 
 #include "activation_functions.cuh"
+#include "matmul.cuh"
 #include "cuda_helper.cuh"
 
@@ -43,93 +44,24 @@ TEST(ActivationFunctionsTest, SigmoidSanityCheck) {
 
     cudaFree(d_input);
     cudaFree(d_output);
+
+    cudaDeviceReset();
 }
 
-TEST(ActivationFunctionsTest, SoftmaxExpTest) {
-    cudaError_t cudaStatus;
+// void print_vec(float* d_vec, int length) {
 
-    float input[6] = {22.496f, 36.9006f, 30.9904f,
-                      28.4213f, 26.4541f, 31.7887f};
+//     std::vector<float> h_vec(length);
+//     CUDA_CHECK(cudaMemcpy(
+//         h_vec.data(), d_vec, sizeof(float) * length, cudaMemcpyDeviceToHost
+//     ));
 
-    std::vector<float> expected = {5886928896.0f, 1.06102872080384e+16f,
-                                   28771323215872.0f, 2204012904448.0f,
-                                   308226162688.0f, 63922983927808.0f};
+//     float sum = 0.0f;
 
-    float* d_input;
-    float* d_output;
+//     for (int i = 0; i < length; ++i) {
+//         std::cout << h_vec[i] << ", ";
+//         sum += h_vec[i];
+//     }
 
-    cudaStatus = cudaMalloc((void**)&d_input, sizeof(float) * 6);
-    EXPECT_EQ(cudaStatus, cudaSuccess);
+//     std::cout << std::endl;
 
-    cudaStatus = cudaMalloc((void**)&d_output, sizeof(float) * 6);
-    EXPECT_EQ(cudaStatus, cudaSuccess);
-
-    cudaStatus =
-        cudaMemcpy(d_input, input, sizeof(float) * 6, cudaMemcpyHostToDevice);
-    EXPECT_EQ(cudaStatus, cudaSuccess);
-
-    CUDANet::Kernels::softmax_exp<<<1, 6>>>(d_input, d_output, 6);
-    cudaStatus = cudaDeviceSynchronize();
-    EXPECT_EQ(cudaStatus, cudaSuccess);
-
-    std::vector<float> output(6);
-
-    cudaStatus = cudaMemcpy(
-        output.data(), d_output, sizeof(float) * 6, cudaMemcpyDeviceToHost
-    );
-    EXPECT_EQ(cudaStatus, cudaSuccess);
-
-    for (int i = 0; i < 6; i++) {
-        EXPECT_NEAR(expected[i], output[i], 1e7);
-    }
-
-    cudaFree(d_input);
-    cudaFree(d_output);
-}
-
-TEST(ActivationFunctionsTest, SoftmaxSumTest) {
-    cudaError_t cudaStatus;
-
-    const int n = 10;
-    std::vector<float> input(n);
-    for (int i = 0; i < n; i++) {
-        input[i] = i;
-    }
-
-    const float expected = n * (n - 1) / 2;
-
-    float* d_input;
-    float* d_sum;
-
-    const int gridSize = (n + BLOCK_SIZE - 1) / BLOCK_SIZE;
-
-    cudaStatus = cudaMalloc((void**)&d_input, sizeof(float) * n);
-    EXPECT_EQ(cudaStatus, cudaSuccess);
-
-    cudaStatus = cudaMalloc((void**)&d_sum, sizeof(float) * n);
-    EXPECT_EQ(cudaStatus, cudaSuccess);
-
-    cudaStatus =
-        cudaMemcpy(d_input, input.data(), sizeof(float) * n, cudaMemcpyHostToDevice);
-    EXPECT_EQ(cudaStatus, cudaSuccess);
-
-    CUDANet::Kernels::softmax_sum<<<gridSize, BLOCK_SIZE>>>(
-        d_input, d_sum
-    );
-
-    CUDANet::Kernels::softmax_sum<<<1, BLOCK_SIZE>>>(
-        d_sum, d_sum
-    );
-
-    CUDANet::Kernels::softmax_sum<<<1, BLOCK_SIZE>>>(
-        d_sum, d_sum
-    );
-
-    std::vector<float> sum(n);
-    cudaStatus = cudaMemcpy(
-        sum.data(), d_sum, sizeof(float) * n, cudaMemcpyDeviceToHost
-    );
-    EXPECT_EQ(cudaStatus, cudaSuccess);
-
-    EXPECT_FLOAT_EQ(expected, sum[0]);
-}
\ No newline at end of file
+// }
\ No newline at end of file
diff --git a/test/kernels/test_matmul.cu b/test/kernels/test_matmul.cu
index 9763f4a..bb42927 100644
--- a/test/kernels/test_matmul.cu
+++ b/test/kernels/test_matmul.cu
@@ -73,33 +73,141 @@ TEST(MatMulTest, MatVecMulTest) {
 
 TEST(MatMulTest, MaxReduceTest) {
     cudaError_t cudaStatus;
 
-    std::vector<float> input = {0.643f, 0.912f, 0.723f, 0.587f, 0.155f, 0.932f, 0.391f, 0.279f, 0.846f, 0.788f};
+    const int n = 1 << 16;
+
+    std::vector<float> input(n);
+    for (int i = 0; i < n; i++) {
+        input[i] = i;
+    }
 
     float* d_input;
     float* d_output;
 
-    cudaStatus = cudaMalloc((void**)&d_input, sizeof(float) * 10);
+    cudaStatus = cudaMalloc((void**)&d_input, sizeof(float) * n);
     EXPECT_EQ(cudaStatus, cudaSuccess);
 
-    cudaStatus = cudaMalloc((void**)&d_output, sizeof(float));
+    cudaStatus = cudaMalloc((void**)&d_output, sizeof(float) * n);
     EXPECT_EQ(cudaStatus, cudaSuccess);
 
-    cudaStatus = cudaMemcpy(d_input, input.data(), sizeof(float) * 10, cudaMemcpyHostToDevice);
+    cudaStatus = cudaMemcpy(d_input, input.data(), sizeof(float) * n, cudaMemcpyHostToDevice);
     EXPECT_EQ(cudaStatus, cudaSuccess);
 
-    const int grid_size = (10 + BLOCK_SIZE - 1) / BLOCK_SIZE;
+    const int grid_size = (n + BLOCK_SIZE - 1) / BLOCK_SIZE;
 
-    CUDANet::Kernels::max_reduce<<<grid_size, BLOCK_SIZE>>>(d_input, d_output);
-    CUDANet::Kernels::max_reduce<<<1, BLOCK_SIZE>>>(d_output, d_output);
+    CUDANet::Kernels::max_reduce<<<grid_size, BLOCK_SIZE>>>(d_input, d_output, n);
 
-    std::vector<float> output(10);
+    int remaining = grid_size;
+    while (remaining > 1) {
+        int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
+        CUDANet::Kernels::max_reduce<<<blocks_needed, BLOCK_SIZE>>>(d_output, d_output, remaining);
+        remaining = blocks_needed;
+    }
+
+    std::vector<float> output(n);
 
     cudaStatus = cudaMemcpy(output.data(), d_output, sizeof(float), cudaMemcpyDeviceToHost);
     EXPECT_EQ(cudaStatus, cudaSuccess);
 
-    EXPECT_EQ(output[0], 0.932f);
+    EXPECT_EQ(output[0], 65535.0f);
 
     cudaFree(d_input);
     cudaFree(d_output);
+
+    cudaDeviceReset();
+}
+
+TEST(MatMulTest, VecExpTest) {
+    cudaError_t cudaStatus;
+
+    float input[6] = {22.496f, 36.9006f, 30.9904f,
+                      28.4213f, 26.4541f, 31.7887f};
+
+    std::vector<float> expected = {5886928896.0f, 1.06102872080384e+16f,
+                                   28771323215872.0f, 2204012904448.0f,
+                                   308226162688.0f, 63922983927808.0f};
+
+    float* d_input;
+    float* d_output;
+
+    cudaStatus = cudaMalloc((void**)&d_input, sizeof(float) * 6);
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    cudaStatus = cudaMalloc((void**)&d_output, sizeof(float) * 6);
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    cudaStatus =
+        cudaMemcpy(d_input, input, sizeof(float) * 6, cudaMemcpyHostToDevice);
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    CUDANet::Kernels::vec_exp<<<1, 6>>>(d_input, d_output, 6);
+    cudaStatus = cudaDeviceSynchronize();
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    std::vector<float> output(6);
+
+    cudaStatus = cudaMemcpy(
+        output.data(), d_output, sizeof(float) * 6, cudaMemcpyDeviceToHost
+    );
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    for (int i = 0; i < 6; i++) {
+        EXPECT_NEAR(expected[i], output[i], 1e7);
+    }
+
+    cudaFree(d_input);
+    cudaFree(d_output);
+
+    cudaDeviceReset();
+}
+
+TEST(MatMulTest, SumReduceTest) {
+    cudaError_t cudaStatus;
+
+    const int n = 1 << 16;
+
+    std::vector<float> input(n);
+    for (int i = 0; i < n; i++) {
+        input[i] = 1.0f;
+    }
+
+    const float expected = n;
+
+    float* d_input = nullptr;
+    float* d_sum = nullptr;
+
+    const int gridSize = (n + BLOCK_SIZE - 1) / BLOCK_SIZE;
+
+    cudaStatus = cudaMalloc((void**)&d_input, sizeof(float) * n);
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    cudaStatus = cudaMalloc((void**)&d_sum, sizeof(float) * n);
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    cudaStatus =
+        cudaMemcpy(d_input, input.data(), sizeof(float) * n, cudaMemcpyHostToDevice);
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    CUDANet::Kernels::sum_reduce<<<gridSize, BLOCK_SIZE>>>(
+        d_input, d_sum, n
+    );
+
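+    // Re-launch on the per-block partials until a single total remains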
+    int remaining = gridSize;
+    while (remaining > 1) {
+        int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
+        CUDANet::Kernels::sum_reduce<<<blocks_needed, BLOCK_SIZE>>>(d_sum, d_sum, remaining);
+        remaining = blocks_needed;
+    }
+
+    std::vector<float> sum(n);
+    cudaStatus = cudaMemcpy(
+        sum.data(), d_sum, sizeof(float) * n, cudaMemcpyDeviceToHost
+    );
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    EXPECT_FLOAT_EQ(expected, sum[0]);
+
+    cudaFree(d_input);
+    cudaFree(d_sum);
+    cudaDeviceReset();
+}
\ No newline at end of file
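
For context, the generic kernels introduced here can be composed into a numerically stable softmax, exp(x - max(x)) / sum(exp(x - max(x))), which appears to be the motivation for removing the special-purpose softmax_* kernels. The sketch below is illustrative only and not part of the patch: softmax_forward is a hypothetical helper name, the include path is an assumption, and BLOCK_SIZE is assumed to be the same launch-width macro the kernels already use. The re-launch loops mirror the ones in the updated tests.

#include "matmul.cuh"  // assumed include path, as in the tests

// Hypothetical host-side helper (a sketch, not the library's API).
// d_scratch must hold at least (len + BLOCK_SIZE - 1) / BLOCK_SIZE floats
// for the per-block partial results. All launches go to the default
// stream, so they execute in order without explicit synchronization.
static void softmax_forward(
    const float* d_in, float* d_out, float* d_scratch, unsigned int len
) {
    const int grid = (len + BLOCK_SIZE - 1) / BLOCK_SIZE;

    // max(x): one partial per block, then reduce the partials to one value
    CUDANet::Kernels::max_reduce<<<grid, BLOCK_SIZE>>>(d_in, d_scratch, len);
    for (int rem = grid; rem > 1;) {
        int blocks = (rem + BLOCK_SIZE - 1) / BLOCK_SIZE;
        CUDANet::Kernels::max_reduce<<<blocks, BLOCK_SIZE>>>(d_scratch, d_scratch, rem);
        rem = blocks;
    }

    // exp(x - max(x)): subtracting the max avoids overflow in expf
    CUDANet::Kernels::vec_scalar_sub<<<grid, BLOCK_SIZE>>>(d_in, d_out, d_scratch, len);
    CUDANet::Kernels::vec_exp<<<grid, BLOCK_SIZE>>>(d_out, d_out, len);

    // sum of the exponentials, reduced the same way as the max
    CUDANet::Kernels::sum_reduce<<<grid, BLOCK_SIZE>>>(d_out, d_scratch, len);
    for (int rem = grid; rem > 1;) {
        int blocks = (rem + BLOCK_SIZE - 1) / BLOCK_SIZE;
        CUDANet::Kernels::sum_reduce<<<blocks, BLOCK_SIZE>>>(d_scratch, d_scratch, rem);
        rem = blocks;
    }

    // normalize: d_out[i] = exp(x[i] - max) / sum
    CUDANet::Kernels::vec_scalar_div<<<grid, BLOCK_SIZE>>>(d_out, d_out, d_scratch, len);
}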