diff --git a/include/backend.hpp b/include/backend.hpp
index 07045dd..d877ad8 100644
--- a/include/backend.hpp
+++ b/include/backend.hpp
@@ -40,6 +40,18 @@ class Backend {
         const size_t input_size,
         const size_t output_size
     ) = 0;
+
+    virtual CUDANet::Tensor& conv2d(
+        const CUDANet::Tensor& weights,
+        const CUDANet::Tensor& biases,
+        const CUDANet::Tensor& input,
+        CUDANet::Tensor& output,
+        const CUDANet::Shape in_shape,
+        const CUDANet::Shape padding_shape,
+        const CUDANet::Shape kernel_shape,
+        const CUDANet::Shape stride_shape,
+        const CUDANet::Shape out_shape
+    ) = 0;
 };
 
 }  // namespace CUDANet
\ No newline at end of file
diff --git a/include/backend/cuda.cuh b/include/backend/cuda.cuh
index 4489e3d..0ee5247 100644
--- a/include/backend/cuda.cuh
+++ b/include/backend/cuda.cuh
@@ -36,6 +36,18 @@ class CUDA : public Backend {
         const size_t input_size,
         const size_t output_size
     ) override;
+
+    CUDANet::Tensor& conv2d(
+        const CUDANet::Tensor& weights,
+        const CUDANet::Tensor& biases,
+        const CUDANet::Tensor& input,
+        CUDANet::Tensor& output,
+        const CUDANet::Shape in_shape,
+        const CUDANet::Shape padding_shape,
+        const CUDANet::Shape kernel_shape,
+        const CUDANet::Shape stride_shape,
+        const CUDANet::Shape out_shape
+    ) override;
 };
 
 }  // namespace CUDANet::Backend
\ No newline at end of file
diff --git a/include/kernels/convolution.cuh b/include/kernels/convolution.cuh
index e368864..96d5ace 100644
--- a/include/kernels/convolution.cuh
+++ b/include/kernels/convolution.cuh
@@ -1,39 +1,20 @@
-#ifndef CUDANET_CONVOLUTION_H
-#define CUDANET_CONVOLUTION_H
+#pragma once
 
 #include <cuda_runtime.h>
 
 #include "layer.hpp"
 
 namespace CUDANet::Kernels {
 
-/**
- * @brief Convolution kernel
- *
- * @param d_input Device pointer to the input matrix
- * @param d_kernel Device pointer to the convolution kernel
- * @param d_bias Device pointer to the bias
- * @param d_output Device pointer to the output matrix
- * @param inputSize Width and height of the input matrix
- * @param nChannels Number of channels in the input matrix
- * @param kernelSize Width and height of the convolution kernel
- * @param stride Convolution stride
- * @param nFilters Number of output filters
- * @param outputSize Width and height of the output matrix
- */
 __global__ void convolution(
     const float* __restrict__ d_input,
     const float* __restrict__ d_kernel,
     const float* __restrict__ d_bias,
     float* __restrict__ d_output,
-    const shape2d inputSize,
-    const int nChannels,
-    const shape2d paddingSize,
-    const shape2d kernelSize,
-    const shape2d stride,
-    const int nFilters,
-    const shape2d outputSize
+    const Shape input_shape,
+    const Shape padding_shape,
+    const Shape kernel_shape,
+    const Shape stride_shape,
+    const Shape output_shape
);
 
 }  // namespace CUDANet::Kernels
\ No newline at end of file
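Note on the new kernel signature: the Shape arguments are now passed to the __global__ function by value, which is only legal if CUDANet::Shape is trivially copyable with no host-side heap storage (a std::vector-backed shape cannot cross the kernel launch boundary). A minimal sketch of a kernel-safe layout, purely illustrative and not part of this patch:

    // Hypothetical: what a device-copyable Shape could look like.
    struct Shape {
        size_t dims[4];   // fixed-capacity storage, safe to pass by value
        size_t ndims;
        __host__ __device__ size_t operator[](size_t i) const { return dims[i]; }
    };

If Shape is heap-backed, an alternative is unpacking the extents into plain integers at the launch site.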
diff --git a/include/layers/conv2d.hpp b/include/layers/conv2d.hpp
index 7bfe2b7..6c2f6ba 100644
--- a/include/layers/conv2d.hpp
+++ b/include/layers/conv2d.hpp
@@ -1,5 +1,4 @@
-#ifndef CUDANET_CONV_LAYER_H
-#define CUDANET_CONV_LAYER_H
+#pragma once
 
 #include <vector>
 
@@ -12,149 +11,52 @@ namespace CUDANet::Layers {
  * @brief 2D convolutional layer
  *
  */
-class Conv2d : public WeightedLayer, public TwoDLayer {
+class Conv2d : public Layer {
   public:
-    /**
-     * @brief Construct a new Conv 2d layer
-     *
-     * @param inputSize Width and height of the input matrix
-     * @param inputChannels Number of channels in the input matrix
-     * @param kernelSize Width and height of the convolution kernel
-     * @param stride Convolution stride
-     * @param numFilters Number of output filters
-     * @param paddingSize Padding size
-     * @param activationType Activation function type ('RELU', 'SIGMOID',
-     * 'SOFTMAX' or 'NONE')
-     */
     Conv2d(
-        shape2d inputSize,
-        int inputChannels,
-        shape2d kernelSize,
-        shape2d stride,
-        int numFilters,
-        shape2d paddingSize,
-        ActivationType activationType
+        CUDANet::Shape input_shape,
+        CUDANet::Shape kernel_shape,
+        CUDANet::Shape stride_shape,
+        CUDANet::Shape padding_shape,
+        CUDANet::Backend* backend
     );
 
-    /**
-     * @brief Destroy the Conv 2d object
-     *
-     */
     ~Conv2d();
 
-    /**
-     * @brief Forward pass of the convolutional layer
-     *
-     * @param d_input Device pointer to the input matrix
-     * @return Device pointer to the output matrix
-     */
-    float* forward(const float* d_input);
+    CUDANet::Tensor& forward(const CUDANet::Tensor& input) override;
 
-    /**
-     * @brief Set the weights of the convolutional layer
-     *
-     * @param weights_input Pointer to the weights
-     */
-    void setWeights(const float* weights_input);
+    CUDANet::Shape input_shape() override;
 
-    /**
-     * @brief Get the weights of the convolutional layer
-     *
-     * @return std::vector<float>
-     */
-    std::vector<float> getWeights();
+    CUDANet::Shape output_shape() override;
 
-    /**
-     * @brief Set the biases of the convolutional layer
-     *
-     * @param biases_input Pointer to the biases
-     */
-    void setBiases(const float* biases_input);
+    size_t input_size() override;
 
-    /**
-     * @brief Get the biases of the convolutional layer
-     *
-     * @return std::vector<float>
-     */
-    std::vector<float> getBiases();
+    size_t output_size() override;
 
-    /**
-     * @brief Get output size
-     *
-     * @return int output size
-     */
-    int getOutputSize();
+    void set_weights(void* input) override;
 
-    /**
-     * @brief Get input size
-     *
-     * @return int input size
-     */
-    int getInputSize();
+    CUDANet::Tensor& get_weights() override;
 
-    /**
-     * @brief Get the padding size of the layer
-     *
-     * @return int
-     */
-    shape2d getPaddingSize() {
-        return paddingSize;
-    }
+    void set_biases(void* input) override;
 
-    shape2d getOutputDims();
+    CUDANet::Tensor& get_biases() override;
+
+    CUDANet::Shape get_padding_shape();
 
   private:
-    // Inputs
-    shape2d inputSize;
-    int inputChannels;
+    CUDANet::Backend* backend;
 
-    // Outputs
-    shape2d outputSize;
+    CUDANet::Shape in_shape;
+    CUDANet::Shape out_shape;
 
-    // Kernel
-    shape2d kernelSize;
-    shape2d stride;
-    shape2d paddingSize;
-    int numFilters;
+    CUDANet::Shape kernel_shape;
+    CUDANet::Shape stride_shape;
+    CUDANet::Shape padding_shape;
 
-    // Kernels
-    std::vector<float> weights;
-    std::vector<float> biases;
+    CUDANet::Tensor weights;
+    CUDANet::Tensor biases;
 
-    float* forwardCPU(const float* input);
-
-// Cuda
-#ifdef USE_CUDA
-    float* d_output;
-    float* d_weights;
-    float* d_biases;
-
-    float* forwardCUDA(const float* d_input);
-    void initCUDA();
-    void delCUDA();
-
-    /**
-     * @brief Copy weights and biases to the device
-     *
-     */
-    void toCuda();
-#endif
-
-    Activation* activation;
-
-    /**
-     * @brief Initialize weights of the convolutional layer with zeros
-     *
-     */
-    void initializeWeights();
-
-    /**
-     * @brief Initialize biases of the convolutional layer with zeros
-     *
-     */
-    void initializeBiases();
+    CUDANet::Tensor output;
 };
 
 }  // namespace CUDANet::Layers
-
-#endif  // CUDANET_CONV_LAYER_H
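Conv2d now implements the same backend-agnostic Layer interface as Dense, with activation handling dropped from the layer itself. For review context, the base class implied by the overrides above looks roughly like this (reconstructed from usage; the actual declaration in layer.hpp may differ):

    class Layer {
      public:
        virtual ~Layer() = default;
        virtual CUDANet::Tensor& forward(const CUDANet::Tensor& input) = 0;
        virtual CUDANet::Shape input_shape()  = 0;
        virtual CUDANet::Shape output_shape() = 0;
        virtual size_t input_size()  = 0;
        virtual size_t output_size() = 0;
        virtual void set_weights(void* input) = 0;
        virtual CUDANet::Tensor& get_weights() = 0;
        virtual void set_biases(void* input) = 0;
        virtual CUDANet::Tensor& get_biases() = 0;
    };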
diff --git a/include/layers/dense.hpp b/include/layers/dense.hpp
index a74a4ab..5046d8f 100644
--- a/include/layers/dense.hpp
+++ b/include/layers/dense.hpp
@@ -14,7 +14,7 @@ namespace CUDANet::Layers {
 
 class Dense : public Layer {
   public:
-    Dense(CUDANet::Backend *backend, CUDANet::Shape input_shape, CUDANet::Shape output_shape);
+    Dense(CUDANet::Shape input_shape, CUDANet::Shape output_shape, CUDANet::Backend *backend);
 
     ~Dense();
 
diff --git a/src/backends/cuda/kernels/convolution.cu b/src/backends/cuda/kernels/convolution.cu
index bae9729..4472325 100644
--- a/src/backends/cuda/kernels/convolution.cu
+++ b/src/backends/cuda/kernels/convolution.cu
@@ -9,52 +9,50 @@ __global__ void Kernels::convolution(
     const float* __restrict__ d_kernel,
     const float* __restrict__ d_bias,
     float* __restrict__ d_output,
-    const shape2d inputSize,
-    const int nChannels,
-    const shape2d paddingSize,
-    const shape2d kernelSize,
-    const shape2d stride,
-    const int nFilters,
-    const shape2d outputSize
+    const Shape input_shape,
+    const Shape padding_shape,
+    const Shape kernel_shape,
+    const Shape stride_shape,
+    const Shape output_shape
 ) {
     int j = blockDim.x * blockIdx.x + threadIdx.x;
     int i = blockDim.y * blockIdx.y + threadIdx.y;
     int f = blockDim.z * blockIdx.z + threadIdx.z;
 
-    if (i >= outputSize.first || j >= outputSize.second || f >= nFilters) {
+    if (i >= output_shape[0] || j >= output_shape[1] || f >= output_shape[2]) {
         return;
     }
 
     float sum = 0.0f;
 
     // Iterate over kernel and input matrix
-    for (int c = 0; c < nChannels; c++) {
-        for (int k = 0; k < kernelSize.first; k++) {
-            for (int l = 0; l < kernelSize.second; l++) {
+    for (int c = 0; c < input_shape[2]; c++) {
+        for (int k = 0; k < kernel_shape[0]; k++) {
+            for (int l = 0; l < kernel_shape[1]; l++) {
                 // if i, j is in the padding region
-                if (i * stride.first + k < paddingSize.first ||
-                    i * stride.first + k >=
-                        (inputSize.first + paddingSize.first) ||
-                    j * stride.second + l < paddingSize.second ||
-                    j * stride.second + l >=
-                        (inputSize.second + paddingSize.second)) {
+                if (i * stride_shape[0] + k < padding_shape[0] ||
+                    i * stride_shape[0] + k >=
+                        (input_shape[0] + padding_shape[0]) ||
+                    j * stride_shape[1] + l < padding_shape[1] ||
+                    j * stride_shape[1] + l >=
+                        (input_shape[1] + padding_shape[1])) {
                     continue;
                 }
 
                 int kernelIndex =
-                    f * kernelSize.first * kernelSize.second * nChannels +
-                    c * kernelSize.first * kernelSize.second +
-                    k * kernelSize.second + l;
-                int inputIndex = c * inputSize.first * inputSize.second +
-                                 (i * stride.first + k - paddingSize.first) *
-                                     inputSize.second +
-                                 (j * stride.second + l - paddingSize.second);
+                    f * kernel_shape[0] * kernel_shape[1] * input_shape[2] +
+                    c * kernel_shape[0] * kernel_shape[1] +
+                    k * kernel_shape[1] + l;
+                int inputIndex = c * input_shape[0] * input_shape[1] +
+                                 (i * stride_shape[0] + k - padding_shape[0]) *
+                                     input_shape[1] +
+                                 (j * stride_shape[1] + l - padding_shape[1]);
 
                 sum += d_kernel[kernelIndex] * d_input[inputIndex];
             }
         }
     }
 
-    d_output[f * outputSize.first * outputSize.second + i * outputSize.second + j] =
+    d_output[f * output_shape[0] * output_shape[1] + i * output_shape[1] + j] =
         sum + d_bias[f];
 }
\ No newline at end of file
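The kernel walks every (filter, output row, output column) triple and accumulates over channels and the kernel window, using row-major NCHW indexing with implicit zero padding. A naive host-side routine with the same indexing is handy for unit-testing the kernel; a sketch, where conv2d_reference is a hypothetical test helper and not part of this patch:

    // Naive host reference mirroring the kernel's indexing (NCHW, row-major).
    void conv2d_reference(const float* in, const float* w, const float* b, float* out,
                          int in_h, int in_w, int in_c,
                          int pad_h, int pad_w, int k_h, int k_w,
                          int s_h, int s_w, int out_h, int out_w, int n_f) {
        for (int f = 0; f < n_f; ++f)
            for (int i = 0; i < out_h; ++i)
                for (int j = 0; j < out_w; ++j) {
                    float sum = 0.0f;
                    for (int c = 0; c < in_c; ++c)
                        for (int k = 0; k < k_h; ++k)
                            for (int l = 0; l < k_w; ++l) {
                                int y = i * s_h + k - pad_h;  // input row
                                int x = j * s_w + l - pad_w;  // input col
                                if (y < 0 || y >= in_h || x < 0 || x >= in_w) continue;
                                sum += w[((f * in_c + c) * k_h + k) * k_w + l] *
                                       in[(c * in_h + y) * in_w + x];
                            }
                    out[(f * out_h + i) * out_w + j] = sum + b[f];
                }
    }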
diff --git a/src/backends/cuda/layer_ops.cu b/src/backends/cuda/layer_ops.cu
index 9d4fc0e..bcfe0fb 100644
--- a/src/backends/cuda/layer_ops.cu
+++ b/src/backends/cuda/layer_ops.cu
@@ -1,5 +1,6 @@
 #include "backend/cuda.cuh"
 #include "kernels/activation_functions.cuh"
+#include "kernels/convolution.cuh"
 #include "kernels/matmul.cuh"
 #include "utils/cuda_helper.cuh"
 
@@ -57,7 +58,7 @@ CUDANet::Tensor& CUDA::dense(
     const CUDANet::Tensor& weights,
     const CUDANet::Tensor& biases,
     const CUDANet::Tensor& input,
-    CUDANet::Tensor& output, 
+    CUDANet::Tensor& output,
     const size_t input_size,
     const size_t output_size
 ) {
@@ -78,5 +79,34 @@ CUDANet::Tensor& CUDA::dense(
     CUDA_CHECK(cudaGetLastError());
     CUDA_CHECK(cudaDeviceSynchronize());
 
     return output;
-}
\ No newline at end of file
+}
+
+CUDANet::Tensor& CUDA::conv2d(
+    const CUDANet::Tensor& weights,
+    const CUDANet::Tensor& biases,
+    const CUDANet::Tensor& input,
+    CUDANet::Tensor& output,
+    const CUDANet::Shape in_shape,
+    const CUDANet::Shape padding_shape,
+    const CUDANet::Shape kernel_shape,
+    const CUDANet::Shape stride_shape,
+    const CUDANet::Shape out_shape
+) {
+    dim3 block(8, 8, 8);
+    dim3 grid(
+        (out_shape[1] + block.x - 1) / block.x,
+        (out_shape[0] + block.y - 1) / block.y,
+        (out_shape[2] + block.z - 1) / block.z
+    );
+
+    Kernels::convolution<<<grid, block>>>(
+        input.data(), weights.data(), biases.data(),
+        output.data(), in_shape, padding_shape, kernel_shape,
+        stride_shape, out_shape
+    );
+    CUDA_CHECK(cudaGetLastError());
+    CUDA_CHECK(cudaDeviceSynchronize());
+
+    return output;
+}
\ No newline at end of file
diff --git a/src/backends/cuda/layers/conv2d.cu b/src/backends/cuda/layers/conv2d.cu
index 2a6bc41..ba66fd7 100644
--- a/src/backends/cuda/layers/conv2d.cu
+++ b/src/backends/cuda/layers/conv2d.cu
@@ -49,25 +49,5 @@ void Conv2d::toCuda() {
 
 float* Conv2d::forwardCUDA(const float* d_input) {
     // Convolve
-    dim3 block(8, 8, 8);
-    dim3 grid(
-        (outputSize.first + block.x - 1) / block.x,
-        (outputSize.second + block.y - 1) / block.y,
-        (numFilters + block.z - 1) / block.z
-    );
 
-    CUDANet::Utils::clear(d_output, outputSize.first * outputSize.second * numFilters);
-
-    Kernels::convolution<<<grid, block>>>(
-        d_input, d_weights, d_biases, d_output, inputSize, inputChannels,
-        paddingSize, kernelSize, stride, numFilters, outputSize
-    );
-    CUDA_CHECK(cudaGetLastError());
-
-    // Apply activation
-    activation->activate(d_output);
-
-    CUDA_CHECK(cudaDeviceSynchronize());
-
-    return d_output;
 }
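In CUDA::conv2d the 8x8x8 block tiles the output volume; since the kernel reads j from the x dimension, i from y and f from z, grid.x must span the output width (out_shape[1]), grid.y the height (out_shape[0]) and grid.z the filter count (out_shape[2]), each with ceil division:

    // Illustrative grid arithmetic for a 30x30 spatial output with 16 filters.
    constexpr unsigned ceil_div(unsigned n, unsigned d) { return (n + d - 1) / d; }
    static_assert(ceil_div(30, 8) == 4);  // grid.x over cols, grid.y over rows
    static_assert(ceil_div(16, 8) == 2);  // grid.z over filters

Out-of-range threads exit at the bounds check at the top of the kernel. Note also that src/backends/cuda/layers/conv2d.cu is left with an empty forwardCUDA() and a toCuda() that reference members removed from the new header; presumably this file is slated for deletion once all call sites go through the backend.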
diff --git a/src/layers/conv2d.cpp b/src/layers/conv2d.cpp
index d1bdcc9..5172188 100644
--- a/src/layers/conv2d.cpp
+++ b/src/layers/conv2d.cpp
@@ -1,111 +1,136 @@
-#include <algorithm>
-#include <vector>
-
-#include "activation.hpp"
 #include "conv2d.hpp"
+
+#include <format>
+#include <stdexcept>
+
 #include "layer.hpp"
+#include "tensor.hpp"
 
 using namespace CUDANet::Layers;
 
 Conv2d::Conv2d(
-    shape2d inputSize,
-    int inputChannels,
-    shape2d kernelSize,
-    shape2d stride,
-    int numFilters,
-    shape2d paddingSize,
-    ActivationType activationType
+    CUDANet::Shape input_shape,
+    CUDANet::Shape kernel_shape,
+    CUDANet::Shape stride_shape,
+    CUDANet::Shape padding_shape,
+    CUDANet::Backend* backend
 )
-    : inputSize(inputSize),
-      inputChannels(inputChannels),
-      kernelSize(kernelSize),
-      stride(stride),
-      numFilters(numFilters),
-      paddingSize(paddingSize) {
-    outputSize = {
-        (inputSize.first - kernelSize.first + 2 * paddingSize.first) /
-                stride.first +
-            1,
-        (inputSize.second - kernelSize.second + 2 * paddingSize.second) /
-                stride.second +
-            1
-    };
+    : backend(backend),
+      in_shape(input_shape),
+      kernel_shape(kernel_shape),
+      stride_shape(stride_shape),
+      padding_shape(padding_shape) {
+    if (in_shape.size() != 3) {
+        throw std::runtime_error(
+            std::format(
+                "Invalid input shape. Expected 3 dims, got {}", in_shape
+            )
+        );
+    }
 
-    activation = new Activation(
-        activationType, outputSize.first * outputSize.second * numFilters
+    if (kernel_shape.size() != 3) {
+        throw std::runtime_error(
+            std::format(
+                "Invalid kernel shape. Expected 3 dims, got {}", kernel_shape
+            )
+        );
+    }
+
+    if (stride_shape.size() != 2) {
+        throw std::runtime_error(
+            std::format(
+                "Invalid stride shape. Expected 2 dims, got {}", stride_shape
+            )
+        );
+    }
+
+    if (padding_shape.size() != 2) {
+        throw std::runtime_error(
+            std::format(
+                "Invalid padding shape. Expected 2 dims, got {}", padding_shape
+            )
+        );
+    }
+
+    size_t out_h = (in_shape[0] - kernel_shape[0] + 2 * padding_shape[0]) /
+                       stride_shape[0] +
+                   1;
+    size_t out_w = (in_shape[1] - kernel_shape[1] + 2 * padding_shape[1]) /
+                       stride_shape[1] +
+                   1;
+    out_shape.resize(3);
+    out_shape[0] = out_h;
+    out_shape[1] = out_w;
+    out_shape[2] = kernel_shape[2];
+
+    output = CUDANet::Tensor(
+        Shape{out_shape[0] * out_shape[1] * out_shape[2]},
+        CUDANet::DType::FLOAT32, backend
    );
 
-    weights.resize(
-        kernelSize.first * kernelSize.second * inputChannels * numFilters
+    weights = CUDANet::Tensor(
+        Shape{
+            kernel_shape[0] * kernel_shape[1] * kernel_shape[2] * in_shape[2]
+        },
+        CUDANet::DType::FLOAT32, backend
+    );
+    biases = CUDANet::Tensor(
+        Shape{kernel_shape[2]}, CUDANet::DType::FLOAT32, backend
     );
-    initializeWeights();
-    biases.resize(numFilters);
-    initializeBiases();
 
-#ifdef USE_CUDA
-    initCUDA();
-    toCuda();
-#endif
+    weights.zero();
+    biases.zero();
 }
 
-Conv2d::~Conv2d() {
-#ifdef USE_CUDA
-    delCUDA();
-#endif
-    delete activation;
+Conv2d::~Conv2d() {}
+
+CUDANet::Tensor& Conv2d::forward(const CUDANet::Tensor& input) {
+    output.zero();
+    backend->conv2d(
+        weights,
+        biases,
+        input,
+        output,
+        in_shape,
+        padding_shape,
+        kernel_shape,
+        stride_shape,
+        out_shape
+    );
+    return output;
 }
 
-void Conv2d::initializeWeights() {
-    std::fill(weights.begin(), weights.end(), 0.0f);
+CUDANet::Shape Conv2d::input_shape() {
+    return in_shape;
 }
 
-void Conv2d::initializeBiases() {
-    std::fill(biases.begin(), biases.end(), 0.0f);
+CUDANet::Shape Conv2d::output_shape() {
+    return out_shape;
 }
 
-void Conv2d::setWeights(const float* weights_input) {
-    std::copy(weights_input, weights_input + weights.size(), weights.begin());
-#ifdef USE_CUDA
-    toCuda();
-#endif
+size_t Conv2d::input_size() {
+    return sizeof(float) * in_shape[0] * in_shape[1] * in_shape[2];
 }
 
-std::vector<float> Conv2d::getWeights() {
+size_t Conv2d::output_size() {
+    return sizeof(float) * out_shape[0] * out_shape[1] * out_shape[2];
+}
+
+void Conv2d::set_weights(void* input) {
+    weights.set_data(static_cast<float*>(input));
+}
+
+CUDANet::Tensor& Conv2d::get_weights() {
     return weights;
 }
 
-void Conv2d::setBiases(const float* biases_input) {
-    std::copy(biases_input, biases_input + biases.size(), biases.begin());
-#ifdef USE_CUDA
-    toCuda();
-#endif
+void Conv2d::set_biases(void* input) {
+    biases.set_data(static_cast<float*>(input));
 }
 
-std::vector<float> Conv2d::getBiases() {
+CUDANet::Tensor& Conv2d::get_biases() {
     return biases;
 }
 
-float* Conv2d::forwardCPU(const float* input) {
-    throw std::logic_error("Not implemented");
-}
-
-float* Conv2d::forward(const float* input) {
-#ifdef USE_CUDA
-    return forwardCUDA(input);
-#else
-    return forwardCPU(input);
-#endif
-}
-
-int Conv2d::getOutputSize() {
-    return outputSize.first * outputSize.second * numFilters;
-}
-
-int Conv2d::getInputSize() {
-    return inputSize.first * inputSize.second * inputChannels;
-}
-
-shape2d Conv2d::getOutputDims() {
-    return outputSize;
+CUDANet::Shape Conv2d::get_padding_shape() {
+    return padding_shape;
 }
\ No newline at end of file
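The output extent per spatial axis is the usual (in - kernel + 2 * padding) / stride + 1, worked through here for two representative configurations:

    // Same formula as the constructor above (integer division).
    constexpr size_t conv_out(size_t in, size_t k, size_t p, size_t s) {
        return (in - k + 2 * p) / s + 1;
    }
    static_assert(conv_out(32, 3, 1, 1) == 32);  // 3x3, pad 1, stride 1: size preserved
    static_assert(conv_out(32, 3, 0, 2) == 15);  // no padding, stride 2

The division truncates, so strides that do not evenly divide the span simply drop the trailing partial window.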
diff --git a/src/layers/dense.cpp b/src/layers/dense.cpp
index 1ba8c4c..d0aa6a3 100644
--- a/src/layers/dense.cpp
+++ b/src/layers/dense.cpp
@@ -5,34 +5,30 @@
 
 using namespace CUDANet::Layers;
 
-Dense::Dense(CUDANet::Backend* backend, CUDANet::Shape in, CUDANet::Shape out)
+Dense::Dense(CUDANet::Shape in, CUDANet::Shape out, CUDANet::Backend* backend)
     : backend(backend),
       in_shape(in),
-      out_shape(out),
-      weights(
-          CUDANet::Tensor(Shape{in[0] * out[0]}, CUDANet::DType::FLOAT32, backend)
-      ),
-      biases(CUDANet::Tensor(Shape{out[0]}, CUDANet::DType::FLOAT32, backend)),
-      output(CUDANet::Tensor(Shape{out[0]}, CUDANet::DType::FLOAT32, backend)) {
-    // Allocate memory for weights and biases
+      out_shape(out) {
     if (in.size() != 1) {
         throw std::runtime_error(
-            std::format("Invalid shape. Expected [1], got {}", in)
+            std::format("Invalid shape. Expected 1 dim, got {}", in_shape)
         );
     }
 
     if (out.size() != 1) {
         throw std::runtime_error(
-            std::format("Invalid shape. Expected [1], got {}", out)
+            std::format("Invalid shape. Expected 1 dim, got {}", out_shape)
         );
     }
 
-    auto input_len  = in[0];
-    auto output_len = out[0];
+    weights = CUDANet::Tensor(Shape{in[0] * out[0]}, CUDANet::DType::FLOAT32, backend);
+    biases  = CUDANet::Tensor(Shape{out[0]}, CUDANet::DType::FLOAT32, backend);
+    output  = CUDANet::Tensor(Shape{out[0]}, CUDANet::DType::FLOAT32, backend);
 
     weights.zero();
     biases.zero();
+    output.zero();
 }
 
 Dense::~Dense() {}
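For reference, a minimal usage sketch of the refactored API, with both layers taking the backend as the trailing constructor argument. Everything outside the constructors and forward() calls (the backend pointer and input tensor setup) is assumed context, not part of this patch:

    #include "conv2d.hpp"
    #include "dense.hpp"

    void run(CUDANet::Backend* backend, CUDANet::Tensor& image) {
        using namespace CUDANet;

        // 32x32 RGB input; 3x3 kernel, 16 filters; stride 1; padding 1.
        Layers::Conv2d conv(
            Shape{32, 32, 3},  // input  {h, w, channels}
            Shape{3, 3, 16},   // kernel {h, w, filters}
            Shape{1, 1},       // stride
            Shape{1, 1},       // padding
            backend
        );

        // Conv output is flattened to h * w * filters for the dense layer.
        Layers::Dense fc(Shape{32 * 32 * 16}, Shape{10}, backend);

        Tensor& features = conv.forward(image);
        Tensor& logits   = fc.forward(features);
        (void)logits;
    }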