Use shared memory for mat vec mul kernel

This commit is contained in:
2024-03-13 22:13:11 +01:00
parent 09480e42e5
commit 77004c16be
4 changed files with 77 additions and 7 deletions

View File

@@ -10,16 +10,22 @@ __global__ void Kernels::mat_vec_mul(
int tid = blockDim.x * blockIdx.x + threadIdx.x; int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= w * h) { extern __shared__ float shared[];
return;
if (tid < w) {
shared[tid] = d_vector[tid];
} }
d_output[tid] = 0.0f; __syncthreads();
for (int i = 0; i < w; i++) { if (tid < h) {
d_output[tid] += d_matrix[tid * w + i] * d_vector[i]; d_output[tid] = 0.0f;
#pragma unroll
for (int i = 0; i < w; i++) {
d_output[tid] += d_matrix[tid * w + i] * shared[i];
}
} }
} }
__global__ void Kernels::vec_vec_add( __global__ void Kernels::vec_vec_add(

View File

@@ -51,7 +51,7 @@ void Layers::Dense::initializeBiases() {
} }
float* Layers::Dense::forward(const float* d_input) { float* Layers::Dense::forward(const float* d_input) {
Kernels::mat_vec_mul<<<1, outputSize>>>( Kernels::mat_vec_mul<<<1, std::max(inputSize, outputSize), sizeof(float) * inputSize>>>(
d_weights, d_input, d_output, inputSize, outputSize d_weights, d_input, d_output, inputSize, outputSize
); );

View File

@@ -7,6 +7,7 @@ add_executable(test_main
layers/test_input.cu layers/test_input.cu
kernels/test_activations.cu kernels/test_activations.cu
kernels/test_padding.cu kernels/test_padding.cu
kernels/test_matmul.cu
) )
target_link_libraries(test_main ${GTEST_BOTH_LIBRARIES} CUDANet) target_link_libraries(test_main ${GTEST_BOTH_LIBRARIES} CUDANet)

View File

@@ -0,0 +1,63 @@
#include <cuda_runtime_api.h>
#include <gtest/gtest.h>
#include <vector>
#include "matmul.cuh"
TEST(MatMulTest, MatVecMulTest) {
cudaError_t cudaStatus;
int w = 10;
int h = 5;
float* d_matrix;
float* d_vector;
float* d_output;
cudaStatus = cudaMalloc((void**)&d_matrix, sizeof(float) * w * h);
EXPECT_EQ(cudaStatus, cudaSuccess);
cudaStatus = cudaMalloc((void**)&d_vector, sizeof(float) * w);
EXPECT_EQ(cudaStatus, cudaSuccess);
cudaStatus = cudaMalloc((void**)&d_output, sizeof(float) * h);
EXPECT_EQ(cudaStatus, cudaSuccess);
std::vector<float> matrix = {
0.643f, 0.912f, 0.723f, 0.587f, 0.155f, 0.932f, 0.391f, 0.279f, 0.846f, 0.788f,
0.641f, 0.445f, 0.528f, 0.316f, 0.247f, 0.181f, 0.549f, 0.328f, 0.919f, 0.405f,
0.733f, 0.287f, 0.901f, 0.602f, 0.816f, 0.495f, 0.797f, 0.210f, 0.305f, 0.613f,
0.178f, 0.856f, 0.724f, 0.263f, 0.559f, 0.677f, 0.193f, 0.389f, 0.488f, 0.848f,
0.121f, 0.734f, 0.587f, 0.904f, 0.312f, 0.672f, 0.807f, 0.478f, 0.581f, 0.964f
};
std::vector<float> vector = {
0.643f, 0.912f, 0.723f, 0.587f, 0.155f, 0.932f, 0.391f, 0.279f, 0.846f, 0.788f
};
cudaStatus = cudaMemcpy(d_matrix, matrix.data(), sizeof(float) * w * h, cudaMemcpyHostToDevice);
EXPECT_EQ(cudaStatus, cudaSuccess);
cudaStatus = cudaMemcpy(d_vector, vector.data(), sizeof(float) * w, cudaMemcpyHostToDevice);
EXPECT_EQ(cudaStatus, cudaSuccess);
int THREADS_PER_BLOCK = std::max(w, h);
int BLOCKS = 1;
Kernels::mat_vec_mul<<<BLOCKS, THREADS_PER_BLOCK, sizeof(float) * w>>>(d_matrix, d_vector, d_output, w, h);
cudaStatus = cudaDeviceSynchronize();
EXPECT_EQ(cudaStatus, cudaSuccess);
std::vector<float> output_gpu(h);
cudaStatus = cudaMemcpy(output_gpu.data(), d_output, sizeof(float) * h, cudaMemcpyDeviceToHost);
EXPECT_EQ(cudaStatus, cudaSuccess);
for (int i = 0; i < h; i++) {
float sum = 0.0f;
for (int j = 0; j < w; j++) {
sum += matrix[i * w + j] * vector[j];
}
EXPECT_NEAR(sum, output_gpu[i], 1e-5);
}
}