diff --git a/src/kernels/matmul.cu b/src/kernels/matmul.cu index d956cf7..2441dbe 100644 --- a/src/kernels/matmul.cu +++ b/src/kernels/matmul.cu @@ -3,7 +3,6 @@ using namespace CUDANet; - __global__ void Kernels::mat_vec_mul( const float* __restrict__ d_matrix, const float* __restrict__ d_vector, @@ -13,32 +12,17 @@ __global__ void Kernels::mat_vec_mul( ) { int tid = blockDim.x * blockIdx.x + threadIdx.x; - __shared__ float shared[BLOCK_SIZE]; + if (tid < h) { + float temp = 0.0f; - float temp = 0.0f; - -#pragma unroll - for (unsigned int i = 0; i < (w + BLOCK_SIZE - 1) / BLOCK_SIZE; i++) { - if (i * BLOCK_SIZE + threadIdx.x < w) { - shared[threadIdx.x] = d_vector[i * BLOCK_SIZE + threadIdx.x]; - } else { - shared[threadIdx.x] = 0.0f; + for (unsigned int j = 0; j < w; j++) { + temp += d_matrix[tid * w + j] * d_vector[j]; } - __syncthreads(); - -#pragma unroll - for (unsigned int j = 0; j < BLOCK_SIZE; j++) { - temp += d_matrix[tid * w + i * BLOCK_SIZE + j] * shared[j]; - } - - __syncthreads(); + d_output[tid] = temp; } - - d_output[tid] = temp; } - __global__ void Kernels::vec_vec_add( const float* __restrict__ d_vector1, const float* __restrict__ d_vector2, @@ -52,7 +36,6 @@ __global__ void Kernels::vec_vec_add( d_output[tid] = d_vector1[tid] + d_vector2[tid]; } - __global__ void Kernels::vec_scalar_sub( const float* __restrict__ d_src, float* __restrict__ d_out, @@ -66,21 +49,19 @@ __global__ void Kernels::vec_scalar_sub( d_out[tid] = d_src[tid] - d_scalar[0]; } - __global__ void Kernels::vec_scalar_div( const float* __restrict__ d_src, float* __restrict__ d_out, const float* __restrict__ d_scalar, const unsigned int len ) { - int tid = blockDim.x * blockIdx.x + threadIdx.x; + int tid = blockDim.x * blockIdx.x + threadIdx.x; if (tid >= len) { return; } d_out[tid] = d_src[tid] / d_scalar[0]; } - __global__ void Kernels::vec_exp( const float* __restrict__ src, float* __restrict__ dst, @@ -107,7 +88,7 @@ __global__ void Kernels::max_reduce( shared_max[threadIdx.x] = d_vector[i]; } else { shared_max[threadIdx.x] = -INFINITY; - } + } __syncthreads(); @@ -129,7 +110,7 @@ __global__ void Kernels::sum_reduce( const unsigned int len ) { __shared__ float partial_sum[BLOCK_SIZE]; - int i = blockIdx.x * blockDim.x + threadIdx.x; + int i = blockIdx.x * blockDim.x + threadIdx.x; if (i < len) { partial_sum[threadIdx.x] = d_vector[i]; diff --git a/src/utils/vector.cu b/src/utils/vector.cu index 70fbaf9..e3db8bc 100644 --- a/src/utils/vector.cu +++ b/src/utils/vector.cu @@ -27,17 +27,7 @@ void Utils::clear(float* d_vec, const unsigned int length) { void Utils::max(float* d_vec, float* d_max, const unsigned int length) { const int grid_size = (length + BLOCK_SIZE - 1) / BLOCK_SIZE; - - std::cout << "grid_size: " << grid_size << ", length: " << length << std::endl; - CUDA_CHECK(cudaGetLastError()); - Kernels::max_reduce<<>>(d_vec, d_max, length); - - std::cout << "input: " << std::endl; - print_vec(d_vec, length); - std::cout << "max: " << std::endl; - print_vec(d_max, length); - CUDA_CHECK(cudaGetLastError()); int remaining = grid_size; diff --git a/test/kernels/test_matmul.cu b/test/kernels/test_matmul.cu index 2364cc3..de2fd0d 100644 --- a/test/kernels/test_matmul.cu +++ b/test/kernels/test_matmul.cu @@ -43,12 +43,12 @@ TEST(MatMulTest, MatVecMulTest) { cudaStatus = cudaMemcpy(d_vector, vector.data(), sizeof(float) * w, cudaMemcpyHostToDevice); EXPECT_EQ(cudaStatus, cudaSuccess); - int THREADS_PER_BLOCK = std::max(w, h); - int BLOCKS = 1; + int grid_size = (std::max(w, h) + BLOCK_SIZE - 1) / BLOCK_SIZE; + CUDANet::Utils::clear(d_output, h); - CUDANet::Kernels::mat_vec_mul<<>>(d_matrix, d_vector, d_output, w, h); + CUDANet::Kernels::mat_vec_mul<<>>(d_matrix, d_vector, d_output, w, h); cudaStatus = cudaDeviceSynchronize(); EXPECT_EQ(cudaStatus, cudaSuccess); @@ -87,7 +87,7 @@ TEST(MatMulTest, MaxReduceTest) { cudaStatus = cudaMalloc((void**)&d_input, sizeof(float) * n); EXPECT_EQ(cudaStatus, cudaSuccess); - cudaStatus = cudaMalloc((void**)&d_output, sizeof(float)); + cudaStatus = cudaMalloc((void**)&d_output, sizeof(float) * n); EXPECT_EQ(cudaStatus, cudaSuccess); cudaStatus = cudaMemcpy(d_input, input.data(), sizeof(float) * n, cudaMemcpyHostToDevice); diff --git a/test/layers/test_activation.cu b/test/layers/test_activation.cu index 7226618..7ed6509 100644 --- a/test/layers/test_activation.cu +++ b/test/layers/test_activation.cu @@ -3,6 +3,7 @@ #include #include + TEST(ActivationTest, SoftmaxTest1) { const int inputSize = 5; cudaError_t cudaStatus; @@ -39,6 +40,9 @@ TEST(ActivationTest, SoftmaxTest1) { cudaFree(d_input); cudaDeviceReset(); + + cudaStatus = cudaGetLastError(); + EXPECT_EQ(cudaStatus, cudaSuccess); } TEST(ActivationTest, SoftmaxTest2) {