Migrate conv2d layer to Tensor

2025-11-19 20:20:46 +01:00
parent 10c84d75fc
commit dfdfa19022
10 changed files with 226 additions and 290 deletions

View File

@@ -40,6 +40,18 @@ class Backend {
         const size_t input_size,
         const size_t output_size
     ) = 0;
+
+    virtual CUDANet::Tensor& conv2d(
+        const CUDANet::Tensor& weights,
+        const CUDANet::Tensor& biases,
+        const CUDANet::Tensor& input,
+        CUDANet::Tensor& output,
+        const CUDANet::Shape in_shape,
+        const CUDANet::Shape padding_shape,
+        const CUDANet::Shape kernel_shape,
+        const CUDANet::Shape stride_shape,
+        const CUDANet::Shape out_shape
+    ) = 0;
 };

 } // namespace CUDANet

View File

@@ -36,6 +36,18 @@ class CUDA : public Backend {
         const size_t input_size,
         const size_t output_size
     ) override;
+
+    CUDANet::Tensor& conv2d(
+        const CUDANet::Tensor& weights,
+        const CUDANet::Tensor& biases,
+        const CUDANet::Tensor& input,
+        CUDANet::Tensor& output,
+        const CUDANet::Shape in_shape,
+        const CUDANet::Shape padding_shape,
+        const CUDANet::Shape kernel_shape,
+        const CUDANet::Shape stride_shape,
+        const CUDANet::Shape out_shape
+    ) override;
 };

 } // namespace CUDANet::Backend

View File

@@ -1,39 +1,20 @@
-#ifndef CUDANET_CONVOLUTION_H
-#define CUDANET_CONVOLUTION_H
+#pragma once

 #include <cuda_runtime.h>

 #include "layer.hpp"

 namespace CUDANet::Kernels {

-/**
- * @brief Convolution kernel
- *
- * @param d_input Device pointer to the input matrix
- * @param d_kernel Device pointer to the convolution kernel
- * @param d_bias Device pointer to the bias
- * @param d_output Device pointer to the output matrix
- * @param inputSize Width and height of the input matrix
- * @param nChannels Number of channels in the input matrix
- * @param kernelSize Width and height of the convolution kernel
- * @param stride Convolution stride
- * @param nFilters Number of output filters
- * @param outputSize Width and height of the output matrix
- */
 __global__ void convolution(
     const float* __restrict__ d_input,
     const float* __restrict__ d_kernel,
     const float* __restrict__ d_bias,
     float* __restrict__ d_output,
-    const shape2d inputSize,
-    const int nChannels,
-    const shape2d paddingSize,
-    const shape2d kernelSize,
-    const shape2d stride,
-    const int nFilters,
-    const shape2d outputSize
+    const Shape input_shape,
+    const Shape padding_shape,
+    const Shape kernel_shape,
+    const Shape stride_shape,
+    const Shape output_shape
 );

 } // namespace CUDANet::Kernels
-
-#endif // CUDANET_CONVOLUTION_H

View File

@@ -1,5 +1,4 @@
-#ifndef CUDANET_CONV_LAYER_H
-#define CUDANET_CONV_LAYER_H
+#pragma once

 #include <vector>

@@ -12,149 +11,52 @@ namespace CUDANet::Layers {
  * @brief 2D convolutional layer
  *
  */
-class Conv2d : public WeightedLayer, public TwoDLayer {
+class Conv2d : public Layer {
    public:
-    /**
-     * @brief Construct a new Conv 2d layer
-     *
-     * @param inputSize Width and height of the input matrix
-     * @param inputChannels Number of channels in the input matrix
-     * @param kernelSize Width and height of the convolution kernel
-     * @param stride Convolution stride
-     * @param numFilters Number of output filters
-     * @param paddingSize Padding size
-     * @param activationType Activation function type ('RELU', 'SIGMOID',
-     * 'SOFTMAX' or 'NONE')
-     */
     Conv2d(
-        shape2d inputSize,
-        int inputChannels,
-        shape2d kernelSize,
-        shape2d stride,
-        int numFilters,
-        shape2d paddingSize,
-        ActivationType activationType
+        CUDANet::Shape input_shape,
+        CUDANet::Shape kernel_shape,
+        CUDANet::Shape stride_shape,
+        CUDANet::Shape padding_shape,
+        CUDANet::Backend* backend
     );

-    /**
-     * @brief Destroy the Conv 2d object
-     *
-     */
-    ~Conv2d();
+    ~Conv2d() {};

-    /**
-     * @brief Forward pass of the convolutional layer
-     *
-     * @param d_input Device pointer to the input matrix
-     * @return Device pointer to the output matrix
-     */
-    float* forward(const float* d_input);
+    CUDANet::Tensor& forward(const CUDANet::Tensor& input) override;

-    /**
-     * @brief Set the weights of the convolutional layer
-     *
-     * @param weights_input Pointer to the weights
-     */
-    void setWeights(const float* weights_input);
+    CUDANet::Shape input_shape() override;

-    /**
-     * @brief Get the weights of the convolutional layer
-     *
-     * @return std::vector<float>
-     */
-    std::vector<float> getWeights();
+    CUDANet::Shape output_shape() override;

-    /**
-     * @brief Set the biases of the convolutional layer
-     *
-     * @param biases_input Pointer to the biases
-     */
-    void setBiases(const float* biases_input);
+    size_t input_size() override;

-    /**
-     * @brief Get the biases of the convolutional layer
-     *
-     * @return std::vector<float>
-     */
-    std::vector<float> getBiases();
+    size_t output_size();

-    /**
-     * @brief Get output size
-     *
-     * @return int output size
-     */
-    int getOutputSize();
+    void set_weights(void* input) override;

-    /**
-     * @brief Get input size
-     *
-     * @return int input size
-     */
-    int getInputSize();
+    CUDANet::Tensor& get_weights() override;

-    /**
-     * @brief Get the padding size of the layer
-     *
-     * @return int
-     */
-    shape2d getPaddingSize() {
-        return paddingSize;
-    }
+    void set_biases(void* input) override;

-    shape2d getOutputDims();
+    CUDANet::Tensor& get_biases() override;
+
+    CUDANet::Shape get_padding_shape();

    private:
-    // Inputs
-    shape2d inputSize;
-    int inputChannels;
-
-    // Outputs
-    shape2d outputSize;
-
-    // Kernel
-    shape2d kernelSize;
-    shape2d stride;
-    shape2d paddingSize;
-    int numFilters;
-
-    // Kernels
-    std::vector<float> weights;
-    std::vector<float> biases;
-
-    float* forwardCPU(const float* input);
-
-    // Cuda
-#ifdef USE_CUDA
-    float* d_output;
-    float* d_weights;
-    float* d_biases;
-
-    float* forwardCUDA(const float* d_input);
-
-    void initCUDA();
-    void delCUDA();
-
-    /**
-     * @brief Copy weights and biases to the device
-     *
-     */
-    void toCuda();
-#endif
-
-    Activation* activation;
-
-    /**
-     * @brief Initialize weights of the convolutional layer with zeros
-     *
-     */
-    void initializeWeights();
-
-    /**
-     * @brief Initialize biases of the convolutional layer with zeros
-     *
-     */
-    void initializeBiases();
+    CUDANet::Backend* backend;
+
+    CUDANet::Shape in_shape;
+    CUDANet::Shape out_shape;
+    CUDANet::Shape kernel_shape;
+    CUDANet::Shape stride_shape;
+    CUDANet::Shape padding_shape;
+
+    CUDANet::Tensor weights;
+    CUDANet::Tensor biases;
+
+    CUDANet::Tensor output;
 };

 } // namespace CUDANet::Layers
-
-#endif // CUDANET_CONV_LAYER_H

View File

@@ -14,7 +14,7 @@ namespace CUDANet::Layers {
 class Dense : public Layer {
    public:
-    Dense(CUDANet::Backend *backend, CUDANet::Shape input_shape, CUDANet::Shape output_shape);
+    Dense(CUDANet::Shape input_shape, CUDANet::Shape output_shape, CUDANet::Backend *backend);
     ~Dense();

View File

@@ -9,52 +9,50 @@ __global__ void Kernels::convolution(
     const float* __restrict__ d_kernel,
     const float* __restrict__ d_bias,
     float* __restrict__ d_output,
-    const shape2d inputSize,
-    const int nChannels,
-    const shape2d paddingSize,
-    const shape2d kernelSize,
-    const shape2d stride,
-    const int nFilters,
-    const shape2d outputSize
+    const Shape input_shape,
+    const Shape padding_shape,
+    const Shape kernel_shape,
+    const Shape stride_shape,
+    const Shape output_shape
 ) {
     int j = blockDim.x * blockIdx.x + threadIdx.x;
     int i = blockDim.y * blockIdx.y + threadIdx.y;
     int f = blockDim.z * blockIdx.z + threadIdx.z;

-    if (i >= outputSize.first || j >= outputSize.second || f >= nFilters) {
+    if (i >= output_shape[0] || j >= output_shape[1] || f >= output_shape[2]) {
         return;
     }

     float sum = 0.0f;

     // Iterate over kernel and input matrix
-    for (int c = 0; c < nChannels; c++) {
-        for (int k = 0; k < kernelSize.first; k++) {
-            for (int l = 0; l < kernelSize.second; l++) {
+    for (int c = 0; c < input_shape[2]; c++) {
+        for (int k = 0; k < kernel_shape[0]; k++) {
+            for (int l = 0; l < kernel_shape[1]; l++) {
                 // if i, j is in the padding region
-                if (i * stride.first + k < paddingSize.first ||
-                    i * stride.first + k >=
-                        (inputSize.first + paddingSize.first) ||
-                    j * stride.second + l < paddingSize.second ||
-                    j * stride.second + l >=
-                        (inputSize.second + paddingSize.second)) {
+                if (i * stride_shape[0] + k < padding_shape[0] ||
+                    i * stride_shape[0] + k >=
+                        (input_shape[0] + padding_shape[0]) ||
+                    j * stride_shape[1] + l < padding_shape[1] ||
+                    j * stride_shape[1] + l >=
+                        (input_shape[1] + padding_shape[1])) {
                     continue;
                 }

                 int kernelIndex =
-                    f * kernelSize.first * kernelSize.second * nChannels +
-                    c * kernelSize.first * kernelSize.second +
-                    k * kernelSize.second + l;
-                int inputIndex = c * inputSize.first * inputSize.second +
-                                 (i * stride.first + k - paddingSize.first) *
-                                     inputSize.second +
-                                 (j * stride.second + l - paddingSize.second);
+                    f * kernel_shape[0] * kernel_shape[1] * input_shape[2] +
+                    c * kernel_shape[0] * kernel_shape[1] +
+                    k * kernel_shape[1] + l;
+                int inputIndex = c * input_shape[0] * input_shape[1] +
+                                 (i * stride_shape[0] + k - padding_shape[0]) *
+                                     input_shape[1] +
+                                 (j * stride_shape[1] + l - padding_shape[1]);

                 sum += d_kernel[kernelIndex] * d_input[inputIndex];
             }
         }
     }

-    d_output[f * outputSize.first * outputSize.second + i * outputSize.second + j] =
+    d_output[f * output_shape[0] * output_shape[1] + i * output_shape[1] + j] =
         sum + d_bias[f];
 }
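
Reviewer note: the rewritten indexing can be sanity-checked against a plain CPU loop. Below is a minimal sketch (not part of this commit) that mirrors the kernel's arithmetic, assuming Shape indexes like a std::vector<size_t>, with input {H, W, C}, padding {Ph, Pw}, kernel {Kh, Kw, F}, stride {Sh, Sw}, and output {Oh, Ow, F}, all buffers dense row-major float:

#include <cstddef>
#include <vector>

// Hypothetical CPU mirror of Kernels::convolution, for checking the indexing.
void conv2d_reference(
    const std::vector<float>& input,    // C * H * W, channel-major
    const std::vector<float>& kernel,   // F * C * Kh * Kw
    const std::vector<float>& bias,     // F
    std::vector<float>&       output,   // F * Oh * Ow
    const std::vector<size_t>& in,      // {H, W, C}
    const std::vector<size_t>& pad,     // {Ph, Pw}
    const std::vector<size_t>& ker,     // {Kh, Kw, F}
    const std::vector<size_t>& stride,  // {Sh, Sw}
    const std::vector<size_t>& out      // {Oh, Ow, F}
) {
    for (size_t f = 0; f < out[2]; f++)
        for (size_t i = 0; i < out[0]; i++)
            for (size_t j = 0; j < out[1]; j++) {
                float sum = 0.0f;
                for (size_t c = 0; c < in[2]; c++)
                    for (size_t k = 0; k < ker[0]; k++)
                        for (size_t l = 0; l < ker[1]; l++) {
                            size_t y = i * stride[0] + k;
                            size_t x = j * stride[1] + l;
                            // Skip taps that fall into the zero padding.
                            if (y < pad[0] || y >= in[0] + pad[0] ||
                                x < pad[1] || x >= in[1] + pad[1])
                                continue;
                            sum += kernel[f * ker[0] * ker[1] * in[2] +
                                          c * ker[0] * ker[1] + k * ker[1] + l] *
                                   input[c * in[0] * in[1] +
                                         (y - pad[0]) * in[1] + (x - pad[1])];
                        }
                output[f * out[0] * out[1] + i * out[1] + j] = sum + bias[f];
            }
}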

View File

@@ -1,5 +1,6 @@
 #include "backend/cuda.cuh"
 #include "kernels/activation_functions.cuh"
+#include "kernels/convolution.cuh"
 #include "kernels/matmul.cuh"
 #include "utils/cuda_helper.cuh"

@@ -57,7 +58,7 @@ CUDANet::Tensor& CUDA::dense(
     const CUDANet::Tensor& weights,
     const CUDANet::Tensor& biases,
     const CUDANet::Tensor& input,
     CUDANet::Tensor& output,
     const size_t input_size,
     const size_t output_size
 ) {

@@ -80,3 +81,32 @@ CUDANet::Tensor& CUDA::dense(
     return output;
 }
+
+CUDANet::Tensor& CUDA::conv2d(
+    const CUDANet::Tensor& weights,
+    const CUDANet::Tensor& biases,
+    const CUDANet::Tensor& input,
+    CUDANet::Tensor& output,
+    const CUDANet::Shape in_shape,
+    const CUDANet::Shape padding_shape,
+    const CUDANet::Shape kernel_shape,
+    const CUDANet::Shape stride_shape,
+    const CUDANet::Shape out_shape
+) {
+    dim3 block(8, 8, 8);
+    dim3 grid(
+        (out_shape[0] + block.x - 1) / block.x,
+        (out_shape[1] + block.y - 1) / block.y,
+        (out_shape[2] + block.z - 1) / block.z
+    );
+
+    Kernels::convolution<<<grid, block>>>(
+        input.data<float>(), weights.data<float>(), biases.data<float>(),
+        output.data<float>(), in_shape, padding_shape, kernel_shape,
+        stride_shape, out_shape
+    );
+    CUDA_CHECK(cudaGetLastError());
+    CUDA_CHECK(cudaDeviceSynchronize());
+
+    return output;
+}
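
Reviewer note: the grid uses the usual ceil division so the launch covers the whole output volume even when the extents are not multiples of the block size; the bounds check at the top of Kernels::convolution discards the overhang. A quick worked example with hypothetical sizes:

// Output 30 x 30 with 20 filters, block (8, 8, 8):
//   grid.x = (30 + 8 - 1) / 8 = 4
//   grid.y = (30 + 8 - 1) / 8 = 4
//   grid.z = (20 + 8 - 1) / 8 = 3
// => 32 x 32 x 24 threads launched, of which 30 x 30 x 20 write output.
dim3 block(8, 8, 8);
dim3 grid((30 + 7) / 8, (30 + 7) / 8, (20 + 7) / 8);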

View File

@@ -49,25 +49,5 @@ void Conv2d::toCuda() {

 float* Conv2d::forwardCUDA(const float* d_input) {
     // Convolve
-    dim3 block(8, 8, 8);
-    dim3 grid(
-        (outputSize.first + block.x - 1) / block.x,
-        (outputSize.second + block.y - 1) / block.y,
-        (numFilters + block.z - 1) / block.z
-    );
-
-    CUDANet::Utils::clear(d_output, outputSize.first * outputSize.second * numFilters);
-
-    Kernels::convolution<<<grid, block>>>(
-        d_input, d_weights, d_biases, d_output, inputSize, inputChannels,
-        paddingSize, kernelSize, stride, numFilters, outputSize
-    );
-    CUDA_CHECK(cudaGetLastError());
-
-    // Apply activation
-    activation->activate(d_output);
-
-    CUDA_CHECK(cudaDeviceSynchronize());
-
-    return d_output;
 }

View File

@@ -1,111 +1,136 @@
-#include <stdexcept>
-#include <vector>
-
-#include "activation.hpp"
 #include "conv2d.hpp"
+
+#include <format>
+#include <stdexcept>
+
 #include "layer.hpp"
+#include "tensor.hpp"

 using namespace CUDANet::Layers;

 Conv2d::Conv2d(
-    shape2d inputSize,
-    int inputChannels,
-    shape2d kernelSize,
-    shape2d stride,
-    int numFilters,
-    shape2d paddingSize,
-    ActivationType activationType
+    CUDANet::Shape input_shape,
+    CUDANet::Shape kernel_shape,
+    CUDANet::Shape stride_shape,
+    CUDANet::Shape padding_shape,
+    CUDANet::Backend* backend
 )
-    : inputSize(inputSize),
-      inputChannels(inputChannels),
-      kernelSize(kernelSize),
-      stride(stride),
-      numFilters(numFilters),
-      paddingSize(paddingSize) {
-    outputSize = {
-        (inputSize.first - kernelSize.first + 2 * paddingSize.first) /
-                stride.first +
-            1,
-        (inputSize.second - kernelSize.second + 2 * paddingSize.second) /
-                stride.second +
-            1
-    };
-
-    activation = new Activation(
-        activationType, outputSize.first * outputSize.second * numFilters
-    );
-
-    weights.resize(
-        kernelSize.first * kernelSize.second * inputChannels * numFilters
-    );
-    initializeWeights();
-
-    biases.resize(numFilters);
-    initializeBiases();
-
-#ifdef USE_CUDA
-    initCUDA();
-    toCuda();
-#endif
+    : in_shape(input_shape),
+      kernel_shape(kernel_shape),
+      stride_shape(stride_shape),
+      padding_shape(padding_shape),
+      backend(backend) {
+    if (in_shape.size() != 3) {
+        throw std::runtime_error(
+            std::format(
+                "Invalid input shape. Expected 3 dims, got {}", in_shape
+            )
+        );
+    }
+
+    if (kernel_shape.size() != 3) {
+        throw std::runtime_error(
+            std::format(
+                "Invalid kernel shape. Expected 3 dims, got {}", kernel_shape
+            )
+        );
+    }
+
+    if (stride_shape.size() != 2) {
+        throw std::runtime_error(
+            std::format(
+                "Invalid stride shape. Expected 2 dims, got {}", stride_shape
+            )
+        );
+    }
+
+    if (padding_shape.size() != 2) {
+        throw std::runtime_error(
+            std::format(
+                "Invalid padding shape. Expected 2 dims, got {}", padding_shape
+            )
+        );
+    }
+
+    size_t out_h = (in_shape[0] - kernel_shape[0] + 2 * padding_shape[0]) /
+                       stride_shape[0] +
+                   1;
+    size_t out_w = (in_shape[1] - kernel_shape[1] + 2 * padding_shape[1]) /
+                       stride_shape[1] +
+                   1;
+
+    out_shape.resize(3);
+    out_shape[0] = out_h;
+    out_shape[1] = out_w;
+    out_shape[2] = kernel_shape[2];
+
+    output = CUDANet::Tensor(
+        Shape{out_shape[0] * out_shape[1] * out_shape[2]},
+        CUDANet::DType::FLOAT32, backend
+    );
+
+    weights = CUDANet::Tensor(
+        Shape{
+            kernel_shape[0] * kernel_shape[1] * kernel_shape[2] * in_shape[2]
+        },
+        CUDANet::DType::FLOAT32, backend
+    );
+    biases = CUDANet::Tensor(
+        Shape{kernel_shape[2]}, CUDANet::DType::FLOAT32, backend
+    );
+
+    weights.zero();
+    biases.zero();
 }

-Conv2d::~Conv2d() {
-#ifdef USE_CUDA
-    delCUDA();
-#endif
-    delete activation;
-}
+Conv2d::~Conv2d() {}

-void Conv2d::initializeWeights() {
-    std::fill(weights.begin(), weights.end(), 0.0f);
-}
+CUDANet::Tensor& Conv2d::forward(const CUDANet::Tensor& input) {
+    output.zero();
+
+    backend->conv2d(
+        weights,
+        biases,
+        input,
+        output,
+        in_shape,
+        padding_shape,
+        kernel_shape,
+        stride_shape,
+        out_shape
+    );
+
+    return output;
+}

-void Conv2d::initializeBiases() {
-    std::fill(biases.begin(), biases.end(), 0.0f);
-}
+CUDANet::Shape Conv2d::input_shape() {
+    return in_shape;
+}

-void Conv2d::setWeights(const float* weights_input) {
-    std::copy(weights_input, weights_input + weights.size(), weights.begin());
-#ifdef USE_CUDA
-    toCuda();
-#endif
-}
+CUDANet::Shape Conv2d::output_shape() {
+    return out_shape;
+}

-std::vector<float> Conv2d::getWeights() {
-    return weights;
-}
+size_t Conv2d::input_size() {
+    return sizeof(float) * in_shape[0] * in_shape[1] * in_shape[2];
+}

-void Conv2d::setBiases(const float* biases_input) {
-    std::copy(biases_input, biases_input + biases.size(), biases.begin());
-#ifdef USE_CUDA
-    toCuda();
-#endif
-}
+size_t Conv2d::output_size() {
+    return sizeof(float) * out_shape[0] * out_shape[1] * out_shape[2];
+}

-std::vector<float> Conv2d::getBiases() {
-    return biases;
-}
+void Conv2d::set_weights(void* input) {
+    weights.set_data<float>(static_cast<float*>(input));
+}

-float* Conv2d::forwardCPU(const float* input) {
-    throw std::logic_error("Not implemented");
-}
+CUDANet::Tensor& Conv2d::get_weights() {
+    return weights;
+}

-float* Conv2d::forward(const float* input) {
-#ifdef USE_CUDA
-    return forwardCUDA(input);
-#else
-    return forwardCPU(input);
-#endif
-}
+void Conv2d::set_biases(void* input) {
+    biases.set_data<float>(static_cast<float*>(input));
+}

-int Conv2d::getOutputSize() {
-    return outputSize.first * outputSize.second * numFilters;
-}
+CUDANet::Tensor& Conv2d::get_biases() {
+    return biases;
+}

-int Conv2d::getInputSize() {
-    return inputSize.first * inputSize.second * inputChannels;
-}
-
-shape2d Conv2d::getOutputDims() {
-    return outputSize;
-}
+CUDANet::Shape Conv2d::get_padding_shape() {
+    return padding_shape;
+}
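
Reviewer note: after the migration, the layer is driven entirely through the Backend and Tensor interfaces instead of raw device pointers. A minimal usage sketch, assuming the CUDA backend is default-constructible as CUDANet::Backend::CUDA and that Tensor and Shape behave as in the Dense layer below; this snippet is illustrative and not part of the commit:

#include "backend/cuda.cuh"
#include "conv2d.hpp"

int main() {
    CUDANet::Backend::CUDA backend;  // hypothetical; construction not shown in this diff

    // 28x28 RGB input, 3x3 kernel, 16 filters, stride 1, padding 1:
    // out_h = (28 - 3 + 2 * 1) / 1 + 1 = 28, so out_shape = {28, 28, 16}.
    CUDANet::Layers::Conv2d conv(
        CUDANet::Shape{28, 28, 3},  // input_shape  {H, W, C}
        CUDANet::Shape{3, 3, 16},   // kernel_shape {Kh, Kw, F}
        CUDANet::Shape{1, 1},       // stride_shape
        CUDANet::Shape{1, 1},       // padding_shape
        &backend
    );

    CUDANet::Tensor input(
        CUDANet::Shape{28 * 28 * 3}, CUDANet::DType::FLOAT32, &backend
    );
    input.zero();

    CUDANet::Tensor& out = conv.forward(input);  // dispatches to CUDA::conv2d
    (void)out;
    return 0;
}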

View File

@@ -5,34 +5,30 @@

 using namespace CUDANet::Layers;

-Dense::Dense(CUDANet::Backend* backend, CUDANet::Shape in, CUDANet::Shape out)
+Dense::Dense(CUDANet::Shape in, CUDANet::Shape out, CUDANet::Backend* backend)
     : backend(backend),
       in_shape(in),
-      out_shape(out),
-      weights(
-          CUDANet::Tensor(Shape{in[0] * out[0]}, CUDANet::DType::FLOAT32, backend)
-      ),
-      biases(CUDANet::Tensor(Shape{out[0]}, CUDANet::DType::FLOAT32, backend)),
-      output(CUDANet::Tensor(Shape{out[0]}, CUDANet::DType::FLOAT32, backend)) {
-    // Allocate memory for weights and biases
+      out_shape(out) {
     if (in.size() != 1) {
         throw std::runtime_error(
-            std::format("Invalid shape. Expected [1], got {}", in)
+            std::format("Invalid shape. Expected [1], got {}", in_shape)
         );
     }

     if (out.size() != 1) {
         throw std::runtime_error(
-            std::format("Invalid shape. Expected [1], got {}", out)
+            std::format("Invalid shape. Expected [1], got {}", out_shape)
         );
     }

-    auto input_len = in[0];
-    auto output_len = out[0];
+    weights = CUDANet::Tensor(Shape{in[0] * out[0]}, CUDANet::DType::FLOAT32, backend);
+    biases = CUDANet::Tensor(Shape{out[0]}, CUDANet::DType::FLOAT32, backend);
+    output = CUDANet::Tensor(Shape{out[0]}, CUDANet::DType::FLOAT32, backend);

     weights.zero();
     biases.zero();
+    output.zero();
 }

 Dense::~Dense() {}