diff --git a/include/layers/avg_pooling.cuh b/include/layers/avg_pooling.cuh index f992078..8de40ca 100644 --- a/include/layers/avg_pooling.cuh +++ b/include/layers/avg_pooling.cuh @@ -38,7 +38,7 @@ class AvgPooling2D : public SequentialLayer { float* d_output; - Activation activation; + Activation* activation; }; } // namespace CUDANet::Layers diff --git a/include/layers/conv2d.cuh b/include/layers/conv2d.cuh index c8133cd..af9bed9 100644 --- a/include/layers/conv2d.cuh +++ b/include/layers/conv2d.cuh @@ -120,8 +120,7 @@ class Conv2d : public WeightedLayer { float* d_weights; float* d_biases; - // Kernels - Activation activation; + Activation* activation; /** * @brief Initialize weights of the convolutional layer with zeros diff --git a/include/layers/dense.cuh b/include/layers/dense.cuh index e0e1ab8..dd8c988 100644 --- a/include/layers/dense.cuh +++ b/include/layers/dense.cuh @@ -77,7 +77,7 @@ class Dense : public WeightedLayer { std::vector<float> weights; std::vector<float> biases; - Layers::Activation activation; + Layers::Activation* activation; // Precompute kernel launch parameters unsigned int forwardGridSize; diff --git a/include/layers/max_pooling.cuh b/include/layers/max_pooling.cuh index f89b9e9..cb5d06e 100644 --- a/include/layers/max_pooling.cuh +++ b/include/layers/max_pooling.cuh @@ -38,7 +38,7 @@ class MaxPooling2D : public SequentialLayer { float* d_output; - Activation activation; + Activation* activation; }; } // namespace CUDANet::Layers diff --git a/src/layers/activation.cu b/src/layers/activation.cu index 0991fcf..8171c8b 100644 --- a/src/layers/activation.cu +++ b/src/layers/activation.cu @@ -77,4 +77,3 @@ void Activation::activate(float* d_input) { CUDA_CHECK(cudaDeviceSynchronize()); } - diff --git a/src/layers/avg_pooling.cu b/src/layers/avg_pooling.cu index e1cb40a..951dee5 100644 --- a/src/layers/avg_pooling.cu +++ b/src/layers/avg_pooling.cu @@ -18,7 +18,7 @@ AvgPooling2D::AvgPooling2D( outputSize = (inputSize - poolingSize) / stride + 1; 
activation = - Activation(activationType, outputSize * outputSize * nChannels); + new Activation(activationType, outputSize * outputSize * nChannels); d_output = nullptr; CUDA_CHECK(cudaMalloc( @@ -28,6 +28,7 @@ AvgPooling2D::AvgPooling2D( AvgPooling2D::~AvgPooling2D() { cudaFree(d_output); + delete activation; } float* AvgPooling2D::forward(const float* d_input) { @@ -44,7 +45,7 @@ float* AvgPooling2D::forward(const float* d_input) { ); CUDA_CHECK(cudaGetLastError()); - activation.activate(d_output); + activation->activate(d_output); CUDA_CHECK(cudaDeviceSynchronize()); return d_output; diff --git a/src/layers/conv2d.cu b/src/layers/conv2d.cu index dec60ab..3f1f829 100644 --- a/src/layers/conv2d.cu +++ b/src/layers/conv2d.cu @@ -29,7 +29,7 @@ Conv2d::Conv2d( outputSize = (inputSize - kernelSize + 2 * paddingSize) / stride + 1; - activation = Activation( + activation = new Activation( activationType, outputSize * outputSize * numFilters ); @@ -62,6 +62,7 @@ Conv2d::~Conv2d() { cudaFree(d_output); cudaFree(d_weights); cudaFree(d_biases); + delete activation; } void Conv2d::initializeWeights() { @@ -123,7 +124,7 @@ float* Conv2d::forward(const float* d_input) { CUDA_CHECK(cudaGetLastError()); // Apply activation - activation.activate(d_output); + activation->activate(d_output); CUDA_CHECK(cudaDeviceSynchronize()); diff --git a/src/layers/dense.cu b/src/layers/dense.cu index 80bd5b2..47310bb 100644 --- a/src/layers/dense.cu +++ b/src/layers/dense.cu @@ -45,14 +45,14 @@ Dense::Dense( (std::max(inputSize, outputSize) + BLOCK_SIZE - 1) / BLOCK_SIZE; biasGridSize = (outputSize + BLOCK_SIZE - 1) / BLOCK_SIZE; - activation = Activation(activationType, outputSize); + activation = new Activation(activationType, outputSize); } Dense::~Dense() { - // Free GPU memory cudaFree(d_output); cudaFree(d_weights); cudaFree(d_biases); + delete activation; } void Dense::initializeWeights() { @@ -75,7 +75,7 @@ float* Dense::forward(const float* d_input) { ); 
CUDA_CHECK(cudaGetLastError()); - activation.activate(d_output); + activation->activate(d_output); CUDA_CHECK(cudaDeviceSynchronize()); return d_output; diff --git a/src/layers/max_pooling.cu b/src/layers/max_pooling.cu index a8ed7c6..8f8e3d4 100644 --- a/src/layers/max_pooling.cu +++ b/src/layers/max_pooling.cu @@ -17,7 +17,7 @@ MaxPooling2D::MaxPooling2D( outputSize = (inputSize - 1) / stride + 1; - activation = Activation( + activation = new Activation( activationType, outputSize * outputSize * nChannels ); @@ -30,6 +30,7 @@ MaxPooling2D::MaxPooling2D( MaxPooling2D::~MaxPooling2D() { cudaFree(d_output); + delete activation; } @@ -47,7 +48,7 @@ float* MaxPooling2D::forward(const float* d_input) { ); CUDA_CHECK(cudaGetLastError()); - activation.activate(d_output); + activation->activate(d_output); CUDA_CHECK(cudaDeviceSynchronize()); return d_output; diff --git a/src/utils/vector.cu b/src/utils/vector.cu index e3db8bc..a8b8236 100644 --- a/src/utils/vector.cu +++ b/src/utils/vector.cu @@ -31,6 +31,7 @@ void Utils::max(float* d_vec, float* d_max, const unsigned int length) { CUDA_CHECK(cudaGetLastError()); int remaining = grid_size; + while (remaining > 1) { int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE; CUDANet::Kernels::max_reduce<<<blocks_needed, BLOCK_SIZE>>>(d_max, d_max, remaining);