Migrate conv2d layer to Tensor

This commit is contained in:
2025-11-19 20:20:46 +01:00
parent 10c84d75fc
commit dfdfa19022
10 changed files with 226 additions and 290 deletions

View File

@@ -9,52 +9,50 @@ __global__ void Kernels::convolution(
const float* __restrict__ d_kernel,
const float* __restrict__ d_bias,
float* __restrict__ d_output,
const shape2d inputSize,
const int nChannels,
const shape2d paddingSize,
const shape2d kernelSize,
const shape2d stride,
const int nFilters,
const shape2d outputSize
const Shape input_shape,
const Shape padding_shape,
const Shape kernel_shape,
const Shape stride_shape,
const Shape output_shape
) {
int j = blockDim.x * blockIdx.x + threadIdx.x;
int i = blockDim.y * blockIdx.y + threadIdx.y;
int f = blockDim.z * blockIdx.z + threadIdx.z;
if (i >= outputSize.first || j >= outputSize.second || f >= nFilters) {
if (i >= output_shape[0] || j >= output_shape[1] || f >= output_shape[2]) {
return;
}
float sum = 0.0f;
// Iterate over kernel and input matrix
for (int c = 0; c < nChannels; c++) {
for (int k = 0; k < kernelSize.first; k++) {
for (int l = 0; l < kernelSize.second; l++) {
for (int c = 0; c < input_shape[2]; c++) {
for (int k = 0; k < kernel_shape[0]; k++) {
for (int l = 0; l < kernel_shape[1]; l++) {
// if i, j is in the padding region
if (i * stride.first + k < paddingSize.first ||
i * stride.first + k >=
(inputSize.first + paddingSize.first) ||
j * stride.second + l < paddingSize.second ||
j * stride.second + l >=
(inputSize.second + paddingSize.second)) {
if (i * stride_shape[0] + k < padding_shape[0] ||
i * stride_shape[0] + k >=
(input_shape[0] + padding_shape[0]) ||
j * stride_shape[1] + l < padding_shape[1] ||
j * stride_shape[1] + l >=
(input_shape[1] + padding_shape[1])) {
continue;
}
int kernelIndex =
f * kernelSize.first * kernelSize.second * nChannels +
c * kernelSize.first * kernelSize.second +
k * kernelSize.second + l;
int inputIndex = c * inputSize.first * inputSize.second +
(i * stride.first + k - paddingSize.first) *
inputSize.second +
(j * stride.second + l - paddingSize.second);
f * kernel_shape[0] * kernel_shape[1] * input_shape[2] +
c * kernel_shape[0] * kernel_shape[1] +
k * kernel_shape[1] + l;
int inputIndex = c * input_shape[0] * input_shape[1] +
(i * stride_shape[0] + k - padding_shape[0]) *
input_shape[1] +
(j * stride_shape[1] + l - padding_shape[1]);
sum += d_kernel[kernelIndex] * d_input[inputIndex];
}
}
}
d_output[f * outputSize.first * outputSize.second + i * outputSize.second + j] =
d_output[f * output_shape[0] * output_shape[1] + i * output_shape[1] + j] =
sum + d_bias[f];
}

View File

@@ -1,5 +1,6 @@
#include "backend/cuda.cuh"
#include "kernels/activation_functions.cuh"
#include "kernels/convolution.cuh"
#include "kernels/matmul.cuh"
#include "utils/cuda_helper.cuh"
@@ -57,7 +58,7 @@ CUDANet::Tensor& CUDA::dense(
const CUDANet::Tensor& weights,
const CUDANet::Tensor& biases,
const CUDANet::Tensor& input,
CUDANet::Tensor& output,
CUDANet::Tensor& output,
const size_t input_size,
const size_t output_size
) {
@@ -78,5 +79,34 @@ CUDANet::Tensor& CUDA::dense(
CUDA_CHECK(cudaGetLastError());
CUDA_CHECK(cudaDeviceSynchronize());
return output;
}
CUDANet::Tensor& CUDA::conv2d(
const CUDANet::Tensor& weights,
const CUDANet::Tensor& biases,
const CUDANet::Tensor& input,
CUDANet::Tensor& output,
const CUDANet::Shape in_shape,
const CUDANet::Shape padding_shape,
const CUDANet::Shape kernel_shape,
const CUDANet::Shape stride_shape,
const CUDANet::Shape out_shape
) {
dim3 block(8, 8, 8);
dim3 grid(
(out_shape[0] + block.x - 1) / block.x,
(out_shape[1] + block.y - 1) / block.y,
(out_shape[3] + block.z - 1) / block.z
);
Kernels::convolution<<<grid, block>>>(
input.data<float>(), weights.data<float>(), biases.data<float>(),
output.data<float>(), in_shape, padding_shape, kernel_shape,
stride_shape, out_shape
);
CUDA_CHECK(cudaGetLastError());
CUDA_CHECK(cudaDeviceSynchronize());
return output;
}

View File

@@ -49,25 +49,5 @@ void Conv2d::toCuda() {
float* Conv2d::forwardCUDA(const float* d_input) {
// Convolve
dim3 block(8, 8, 8);
dim3 grid(
(outputSize.first + block.x - 1) / block.x,
(outputSize.second + block.y - 1) / block.y,
(numFilters + block.z - 1) / block.z
);
CUDANet::Utils::clear(d_output, outputSize.first * outputSize.second * numFilters);
Kernels::convolution<<<grid, block>>>(
d_input, d_weights, d_biases, d_output, inputSize, inputChannels,
paddingSize, kernelSize, stride, numFilters, outputSize
);
CUDA_CHECK(cudaGetLastError());
// Apply activation
activation->activate(d_output);
CUDA_CHECK(cudaDeviceSynchronize());
return d_output;
}