mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-06 01:34:22 +00:00
Improve cuda error handling
This commit is contained in:
@@ -38,10 +38,12 @@ void Activation::activate(float* d_input) {
|
|||||||
Kernels::sigmoid<<<gridSize, BLOCK_SIZE>>>(
|
Kernels::sigmoid<<<gridSize, BLOCK_SIZE>>>(
|
||||||
d_input, d_input, length
|
d_input, d_input, length
|
||||||
);
|
);
|
||||||
|
CUDA_CHECK(cudaGetLastError());
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case RELU:
|
case RELU:
|
||||||
Kernels::relu<<<gridSize, BLOCK_SIZE>>>(d_input, d_input, length);
|
Kernels::relu<<<gridSize, BLOCK_SIZE>>>(d_input, d_input, length);
|
||||||
|
CUDA_CHECK(cudaGetLastError());
|
||||||
break;
|
break;
|
||||||
case SOFTMAX:
|
case SOFTMAX:
|
||||||
|
|
||||||
@@ -52,11 +54,13 @@ void Activation::activate(float* d_input) {
|
|||||||
Kernels::vec_scalar_sub<<<gridSize, BLOCK_SIZE>>>(
|
Kernels::vec_scalar_sub<<<gridSize, BLOCK_SIZE>>>(
|
||||||
d_input, d_input, d_max, length
|
d_input, d_input, d_max, length
|
||||||
);
|
);
|
||||||
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
// Compute exponentials
|
// Compute exponentials
|
||||||
Kernels::vec_exp<<<gridSize, BLOCK_SIZE>>>(
|
Kernels::vec_exp<<<gridSize, BLOCK_SIZE>>>(
|
||||||
d_input, d_input, length
|
d_input, d_input, length
|
||||||
);
|
);
|
||||||
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
// Find sum
|
// Find sum
|
||||||
Utils::sum(d_input, d_softmax_sum, length);
|
Utils::sum(d_input, d_softmax_sum, length);
|
||||||
@@ -64,6 +68,7 @@ void Activation::activate(float* d_input) {
|
|||||||
Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
|
Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
|
||||||
d_input, d_input, d_softmax_sum, length
|
d_input, d_input, d_softmax_sum, length
|
||||||
);
|
);
|
||||||
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
break;
|
break;
|
||||||
|
|
||||||
@@ -71,6 +76,6 @@ void Activation::activate(float* d_input) {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
cudaDeviceSynchronize();
|
CUDA_CHECK(cudaDeviceSynchronize());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -25,5 +25,7 @@ void Add::forward(const float* d_inputA, const float* d_inputB) {
|
|||||||
Kernels::vec_vec_add<<<gridSize, BLOCK_SIZE>>>(
|
Kernels::vec_vec_add<<<gridSize, BLOCK_SIZE>>>(
|
||||||
d_inputA, d_inputB, d_output, inputSize
|
d_inputA, d_inputB, d_output, inputSize
|
||||||
);
|
);
|
||||||
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
CUDA_CHECK(cudaDeviceSynchronize());
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -42,8 +42,10 @@ float* AvgPooling2D::forward(const float* d_input) {
|
|||||||
Kernels::avg_pooling<<<grid, block>>>(
|
Kernels::avg_pooling<<<grid, block>>>(
|
||||||
d_input, d_output, inputSize, outputSize, nChannels, poolingSize, stride
|
d_input, d_output, inputSize, outputSize, nChannels, poolingSize, stride
|
||||||
);
|
);
|
||||||
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
activation.activate(d_output);
|
activation.activate(d_output);
|
||||||
|
CUDA_CHECK(cudaDeviceSynchronize());
|
||||||
|
|
||||||
return d_output;
|
return d_output;
|
||||||
}
|
}
|
||||||
@@ -23,10 +23,13 @@ float* Concat::forward(const float* d_input_A, const float* d_input_B) {
|
|||||||
CUDA_CHECK(cudaMemcpy(
|
CUDA_CHECK(cudaMemcpy(
|
||||||
d_output, d_input_A, sizeof(float) * inputASize, cudaMemcpyDeviceToDevice
|
d_output, d_input_A, sizeof(float) * inputASize, cudaMemcpyDeviceToDevice
|
||||||
));
|
));
|
||||||
|
|
||||||
CUDA_CHECK(cudaMemcpy(
|
CUDA_CHECK(cudaMemcpy(
|
||||||
d_output + inputASize, d_input_B,
|
d_output + inputASize, d_input_B,
|
||||||
sizeof(float) * inputBSize, cudaMemcpyDeviceToDevice
|
sizeof(float) * inputBSize, cudaMemcpyDeviceToDevice
|
||||||
));
|
));
|
||||||
|
|
||||||
|
CUDA_CHECK(cudaDeviceSynchronize());
|
||||||
|
|
||||||
return d_output;
|
return d_output;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,11 @@
|
|||||||
#include "convolution.cuh"
|
#include "convolution.cuh"
|
||||||
#include "cuda_helper.cuh"
|
#include "cuda_helper.cuh"
|
||||||
#include "matmul.cuh"
|
#include "matmul.cuh"
|
||||||
|
#include "layer.cuh"
|
||||||
|
#include "vector.cuh"
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
using namespace CUDANet::Layers;
|
using namespace CUDANet::Layers;
|
||||||
|
|
||||||
@@ -100,6 +105,7 @@ void Conv2d::toCuda() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
float* Conv2d::forward(const float* d_input) {
|
float* Conv2d::forward(const float* d_input) {
|
||||||
|
|
||||||
// Convolve
|
// Convolve
|
||||||
dim3 block(8,8,8);
|
dim3 block(8,8,8);
|
||||||
dim3 grid(
|
dim3 grid(
|
||||||
@@ -108,10 +114,13 @@ float* Conv2d::forward(const float* d_input) {
|
|||||||
(numFilters + block.z - 1) / block.z
|
(numFilters + block.z - 1) / block.z
|
||||||
);
|
);
|
||||||
|
|
||||||
|
CUDANet::Utils::clear(d_output, outputSize * outputSize * numFilters);
|
||||||
|
|
||||||
Kernels::convolution<<<grid, block>>>(
|
Kernels::convolution<<<grid, block>>>(
|
||||||
d_input, d_weights, d_biases, d_output, inputSize, inputChannels, paddingSize,
|
d_input, d_weights, d_biases, d_output, inputSize, inputChannels, paddingSize,
|
||||||
kernelSize, stride, numFilters, outputSize
|
kernelSize, stride, numFilters, outputSize
|
||||||
);
|
);
|
||||||
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
// Apply activation
|
// Apply activation
|
||||||
activation.activate(d_output);
|
activation.activate(d_output);
|
||||||
|
|||||||
@@ -65,18 +65,23 @@ void Dense::initializeBiases() {
|
|||||||
|
|
||||||
float* Dense::forward(const float* d_input) {
|
float* Dense::forward(const float* d_input) {
|
||||||
|
|
||||||
CUDANet::Utils::clear(d_output, outputSize);
|
// CUDANet::Utils::clear(d_output, outputSize);
|
||||||
|
|
||||||
|
// CUDA_CHECK(cudaPeekAtLastError());
|
||||||
|
|
||||||
|
std::cout << "Dense::forward" << std::endl;
|
||||||
|
|
||||||
Kernels::mat_vec_mul<<<forwardGridSize, BLOCK_SIZE>>>(
|
Kernels::mat_vec_mul<<<forwardGridSize, BLOCK_SIZE>>>(
|
||||||
d_weights, d_input, d_output, inputSize, outputSize
|
d_weights, d_input, d_output, inputSize, outputSize
|
||||||
);
|
);
|
||||||
|
CUDA_CHECK(cudaPeekAtLastError());
|
||||||
|
|
||||||
Kernels::vec_vec_add<<<biasGridSize, BLOCK_SIZE>>>(
|
Kernels::vec_vec_add<<<biasGridSize, BLOCK_SIZE>>>(
|
||||||
d_biases, d_output, d_output, outputSize
|
d_biases, d_output, d_output, outputSize
|
||||||
);
|
);
|
||||||
|
CUDA_CHECK(cudaPeekAtLastError());
|
||||||
|
|
||||||
activation.activate(d_output);
|
activation.activate(d_output);
|
||||||
|
|
||||||
CUDA_CHECK(cudaDeviceSynchronize());
|
CUDA_CHECK(cudaDeviceSynchronize());
|
||||||
|
|
||||||
return d_output;
|
return d_output;
|
||||||
|
|||||||
@@ -12,17 +12,11 @@ Input::~Input() {
|
|||||||
cudaFree(d_output);
|
cudaFree(d_output);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
|
||||||
Copies host input to device d_output
|
|
||||||
|
|
||||||
Args
|
|
||||||
const float* input Host pointer to input data
|
|
||||||
float* d_output Device pointer to input data copied to device
|
|
||||||
*/
|
|
||||||
float* Input::forward(const float* input) {
|
float* Input::forward(const float* input) {
|
||||||
CUDA_CHECK(cudaMemcpy(
|
CUDA_CHECK(cudaMemcpy(
|
||||||
d_output, input, sizeof(float) * inputSize, cudaMemcpyHostToDevice
|
d_output, input, sizeof(float) * inputSize, cudaMemcpyHostToDevice
|
||||||
));
|
));
|
||||||
|
CUDA_CHECK(cudaDeviceSynchronize());
|
||||||
|
|
||||||
return d_output;
|
return d_output;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -45,8 +45,10 @@ float* MaxPooling2D::forward(const float* d_input) {
|
|||||||
Kernels::max_pooling<<<grid, block>>>(
|
Kernels::max_pooling<<<grid, block>>>(
|
||||||
d_input, d_output, inputSize, outputSize, nChannels, poolingSize, stride
|
d_input, d_output, inputSize, outputSize, nChannels, poolingSize, stride
|
||||||
);
|
);
|
||||||
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
activation.activate(d_output);
|
activation.activate(d_output);
|
||||||
|
CUDA_CHECK(cudaDeviceSynchronize());
|
||||||
|
|
||||||
return d_output;
|
return d_output;
|
||||||
}
|
}
|
||||||
@@ -17,6 +17,7 @@ float* Output::forward(const float* input) {
|
|||||||
CUDA_CHECK(cudaMemcpy(
|
CUDA_CHECK(cudaMemcpy(
|
||||||
h_output, input, sizeof(float) * inputSize, cudaMemcpyDeviceToHost
|
h_output, input, sizeof(float) * inputSize, cudaMemcpyDeviceToHost
|
||||||
));
|
));
|
||||||
|
CUDA_CHECK(cudaDeviceSynchronize());
|
||||||
|
|
||||||
return h_output;
|
return h_output;
|
||||||
}
|
}
|
||||||
@@ -29,11 +29,14 @@ void Utils::max(float* d_vec, float* d_max, const unsigned int length) {
|
|||||||
const int grid_size = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
const int grid_size = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||||
|
|
||||||
Kernels::max_reduce<<<grid_size, BLOCK_SIZE>>>(d_vec, d_max, length);
|
Kernels::max_reduce<<<grid_size, BLOCK_SIZE>>>(d_vec, d_max, length);
|
||||||
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
int remaining = grid_size;
|
int remaining = grid_size;
|
||||||
while (remaining > 1) {
|
while (remaining > 1) {
|
||||||
int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||||
CUDANet::Kernels::max_reduce<<<blocks_needed, BLOCK_SIZE>>>(d_max, d_max, remaining);
|
CUDANet::Kernels::max_reduce<<<blocks_needed, BLOCK_SIZE>>>(d_max, d_max, remaining);
|
||||||
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
remaining = blocks_needed;
|
remaining = blocks_needed;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -47,11 +50,14 @@ void Utils::sum(float* d_vec, float* d_sum, const unsigned int length) {
|
|||||||
CUDANet::Kernels::sum_reduce<<<gridSize, BLOCK_SIZE>>>(
|
CUDANet::Kernels::sum_reduce<<<gridSize, BLOCK_SIZE>>>(
|
||||||
d_vec, d_sum, length
|
d_vec, d_sum, length
|
||||||
);
|
);
|
||||||
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
int remaining = gridSize;
|
int remaining = gridSize;
|
||||||
while (remaining > 1) {
|
while (remaining > 1) {
|
||||||
int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||||
CUDANet::Kernels::sum_reduce<<<blocks_needed, BLOCK_SIZE>>>(d_sum, d_sum, remaining);
|
CUDANet::Kernels::sum_reduce<<<blocks_needed, BLOCK_SIZE>>>(d_sum, d_sum, remaining);
|
||||||
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
remaining = blocks_needed;
|
remaining = blocks_needed;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user