From 0170afaf3f4a18628b1c31f259cc3a31f21db66e Mon Sep 17 00:00:00 2001 From: LordMathis Date: Sun, 21 Apr 2024 12:19:19 +0200 Subject: [PATCH] Improve cuda error handling --- src/layers/activation.cu | 7 ++++++- src/layers/add.cu | 2 ++ src/layers/avg_pooling.cu | 2 ++ src/layers/concat.cu | 3 +++ src/layers/conv2d.cu | 11 ++++++++++- src/layers/dense.cu | 9 +++++++-- src/layers/input.cu | 8 +------- src/layers/max_pooling.cu | 2 ++ src/layers/output.cu | 1 + src/utils/vector.cu | 6 ++++++ 10 files changed, 40 insertions(+), 11 deletions(-) diff --git a/src/layers/activation.cu b/src/layers/activation.cu index fc71bd2..3341483 100644 --- a/src/layers/activation.cu +++ b/src/layers/activation.cu @@ -38,10 +38,12 @@ void Activation::activate(float* d_input) { Kernels::sigmoid<<>>( d_input, d_input, length ); + CUDA_CHECK(cudaGetLastError()); break; case RELU: Kernels::relu<<>>(d_input, d_input, length); + CUDA_CHECK(cudaGetLastError()); break; case SOFTMAX: @@ -52,11 +54,13 @@ void Activation::activate(float* d_input) { Kernels::vec_scalar_sub<<>>( d_input, d_input, d_max, length ); + CUDA_CHECK(cudaGetLastError()); // Compute exponentials Kernels::vec_exp<<>>( d_input, d_input, length ); + CUDA_CHECK(cudaGetLastError()); // Find sum Utils::sum(d_input, d_softmax_sum, length); @@ -64,6 +68,7 @@ void Activation::activate(float* d_input) { Kernels::vec_scalar_div<<>>( d_input, d_input, d_softmax_sum, length ); + CUDA_CHECK(cudaGetLastError()); break; @@ -71,6 +76,6 @@ void Activation::activate(float* d_input) { break; } - cudaDeviceSynchronize(); + CUDA_CHECK(cudaDeviceSynchronize()); } diff --git a/src/layers/add.cu b/src/layers/add.cu index 7539e5f..2c672d9 100644 --- a/src/layers/add.cu +++ b/src/layers/add.cu @@ -25,5 +25,7 @@ void Add::forward(const float* d_inputA, const float* d_inputB) { Kernels::vec_vec_add<<>>( d_inputA, d_inputB, d_output, inputSize ); + CUDA_CHECK(cudaGetLastError()); + CUDA_CHECK(cudaDeviceSynchronize()); } \ No newline at end of file diff --git a/src/layers/avg_pooling.cu b/src/layers/avg_pooling.cu index b37bb9b..e1cb40a 100644 --- a/src/layers/avg_pooling.cu +++ b/src/layers/avg_pooling.cu @@ -42,8 +42,10 @@ float* AvgPooling2D::forward(const float* d_input) { Kernels::avg_pooling<<>>( d_input, d_output, inputSize, outputSize, nChannels, poolingSize, stride ); + CUDA_CHECK(cudaGetLastError()); activation.activate(d_output); + CUDA_CHECK(cudaDeviceSynchronize()); return d_output; } \ No newline at end of file diff --git a/src/layers/concat.cu b/src/layers/concat.cu index 94284e6..f353f7f 100644 --- a/src/layers/concat.cu +++ b/src/layers/concat.cu @@ -23,10 +23,13 @@ float* Concat::forward(const float* d_input_A, const float* d_input_B) { CUDA_CHECK(cudaMemcpy( d_output, d_input_A, sizeof(float) * inputASize, cudaMemcpyDeviceToDevice )); + CUDA_CHECK(cudaMemcpy( d_output + inputASize, d_input_B, sizeof(float) * inputBSize, cudaMemcpyDeviceToDevice )); + CUDA_CHECK(cudaDeviceSynchronize()); + return d_output; } diff --git a/src/layers/conv2d.cu b/src/layers/conv2d.cu index 4d7bbcd..dec60ab 100644 --- a/src/layers/conv2d.cu +++ b/src/layers/conv2d.cu @@ -3,6 +3,11 @@ #include "convolution.cuh" #include "cuda_helper.cuh" #include "matmul.cuh" +#include "layer.cuh" +#include "vector.cuh" + +#include +#include using namespace CUDANet::Layers; @@ -100,6 +105,7 @@ void Conv2d::toCuda() { } float* Conv2d::forward(const float* d_input) { + // Convolve dim3 block(8,8,8); dim3 grid( @@ -108,11 +114,14 @@ float* Conv2d::forward(const float* d_input) { (numFilters + block.z - 1) / block.z ); + CUDANet::Utils::clear(d_output, outputSize * outputSize * numFilters); + Kernels::convolution<<>>( d_input, d_weights, d_biases, d_output, inputSize, inputChannels, paddingSize, kernelSize, stride, numFilters, outputSize ); - + CUDA_CHECK(cudaGetLastError()); + // Apply activation activation.activate(d_output); diff --git a/src/layers/dense.cu b/src/layers/dense.cu index 31abcf5..dae804d 100644 --- a/src/layers/dense.cu +++ b/src/layers/dense.cu @@ -65,18 +65,23 @@ void Dense::initializeBiases() { float* Dense::forward(const float* d_input) { - CUDANet::Utils::clear(d_output, outputSize); + // CUDANet::Utils::clear(d_output, outputSize); + + // CUDA_CHECK(cudaPeekAtLastError()); + + std::cout << "Dense::forward" << std::endl; Kernels::mat_vec_mul<<>>( d_weights, d_input, d_output, inputSize, outputSize ); + CUDA_CHECK(cudaPeekAtLastError()); Kernels::vec_vec_add<<>>( d_biases, d_output, d_output, outputSize ); + CUDA_CHECK(cudaPeekAtLastError()); activation.activate(d_output); - CUDA_CHECK(cudaDeviceSynchronize()); return d_output; diff --git a/src/layers/input.cu b/src/layers/input.cu index 2e5f157..59ec381 100644 --- a/src/layers/input.cu +++ b/src/layers/input.cu @@ -12,17 +12,11 @@ Input::~Input() { cudaFree(d_output); } -/* -Copies host input to device d_output - -Args - const float* input Host pointer to input data - float* d_output Device pointer to input data copied to device -*/ float* Input::forward(const float* input) { CUDA_CHECK(cudaMemcpy( d_output, input, sizeof(float) * inputSize, cudaMemcpyHostToDevice )); + CUDA_CHECK(cudaDeviceSynchronize()); return d_output; } diff --git a/src/layers/max_pooling.cu b/src/layers/max_pooling.cu index f6ae351..a8ed7c6 100644 --- a/src/layers/max_pooling.cu +++ b/src/layers/max_pooling.cu @@ -45,8 +45,10 @@ float* MaxPooling2D::forward(const float* d_input) { Kernels::max_pooling<<>>( d_input, d_output, inputSize, outputSize, nChannels, poolingSize, stride ); + CUDA_CHECK(cudaGetLastError()); activation.activate(d_output); + CUDA_CHECK(cudaDeviceSynchronize()); return d_output; } \ No newline at end of file diff --git a/src/layers/output.cu b/src/layers/output.cu index f032376..a37afa5 100644 --- a/src/layers/output.cu +++ b/src/layers/output.cu @@ -17,6 +17,7 @@ float* Output::forward(const float* input) { CUDA_CHECK(cudaMemcpy( h_output, input, sizeof(float) * inputSize, cudaMemcpyDeviceToHost )); + CUDA_CHECK(cudaDeviceSynchronize()); return h_output; } \ No newline at end of file diff --git a/src/utils/vector.cu b/src/utils/vector.cu index ea97be8..d84a4f2 100644 --- a/src/utils/vector.cu +++ b/src/utils/vector.cu @@ -29,11 +29,14 @@ void Utils::max(float* d_vec, float* d_max, const unsigned int length) { const int grid_size = (length + BLOCK_SIZE - 1) / BLOCK_SIZE; Kernels::max_reduce<<>>(d_vec, d_max, length); + CUDA_CHECK(cudaGetLastError()); int remaining = grid_size; while (remaining > 1) { int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE; CUDANet::Kernels::max_reduce<<>>(d_max, d_max, remaining); + CUDA_CHECK(cudaGetLastError()); + remaining = blocks_needed; } @@ -47,11 +50,14 @@ void Utils::sum(float* d_vec, float* d_sum, const unsigned int length) { CUDANet::Kernels::sum_reduce<<>>( d_vec, d_sum, length ); + CUDA_CHECK(cudaGetLastError()); int remaining = gridSize; while (remaining > 1) { int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE; CUDANet::Kernels::sum_reduce<<>>(d_sum, d_sum, remaining); + CUDA_CHECK(cudaGetLastError()); + remaining = blocks_needed; } } \ No newline at end of file