Add support for non-square matrices

This commit is contained in:
2024-05-20 15:20:43 +02:00
parent 6f8b5f4081
commit 74098b24e3
21 changed files with 314 additions and 299 deletions

View File

@@ -9,19 +9,19 @@ __global__ void Kernels::convolution(
const float* __restrict__ d_kernel,
const float* __restrict__ d_bias,
float* __restrict__ d_output,
const int inputSize,
const int nChannels,
const int paddingSize,
const int kernelSize,
const int stride,
const int nFilters,
const int outputSize
const dim2d inputSize,
const int nChannels,
const dim2d paddingSize,
const dim2d kernelSize,
const dim2d stride,
const int nFilters,
const dim2d outputSize
) {
int j = blockDim.x * blockIdx.x + threadIdx.x;
int i = blockDim.y * blockIdx.y + threadIdx.y;
int f = blockDim.z * blockIdx.z + threadIdx.z;
if (i >= outputSize || j >= outputSize || f >= nFilters) {
if (i >= outputSize.first || j >= outputSize.second || f >= nFilters) {
return;
}
@@ -29,28 +29,32 @@ __global__ void Kernels::convolution(
// Iterate over kernel and input matrix
for (int c = 0; c < nChannels; c++) {
for (int k = 0; k < kernelSize; k++) {
for (int l = 0; l < kernelSize; l++) {
for (int k = 0; k < kernelSize.first; k++) {
for (int l = 0; l < kernelSize.second; l++) {
// Skip this kernel tap if the input coordinate it maps to lies in the padding region
if (i * stride + k < paddingSize ||
i * stride + k >= (inputSize + paddingSize) ||
j * stride + l < paddingSize ||
j * stride + l >= (inputSize + paddingSize)) {
if (i * stride.first + k < paddingSize.first ||
i * stride.first + k >=
(inputSize.first + paddingSize.first) ||
j * stride.second + l < paddingSize.second ||
j * stride.second + l >=
(inputSize.second + paddingSize.second)) {
continue;
}
int kernelIndex = f * kernelSize * kernelSize * nChannels +
c * kernelSize * kernelSize + k * kernelSize +
l;
int inputIndex = c * inputSize * inputSize +
(i * stride + k - paddingSize) * inputSize +
(j * stride + l - paddingSize);
int kernelIndex =
f * kernelSize.first * kernelSize.second * nChannels +
c * kernelSize.first * kernelSize.second +
k * kernelSize.second + l;
int inputIndex = c * inputSize.first * inputSize.second +
(i * stride.first + k - paddingSize.first) *
inputSize.second +
(j * stride.second + l - paddingSize.second);
sum += d_kernel[kernelIndex] * d_input[inputIndex];
}
}
}
d_output[f * outputSize * outputSize + i * outputSize + j] = sum + d_bias[f];
d_output[f * outputSize.first * outputSize.second + i * outputSize.second + j] =
sum + d_bias[f];
}