Combine padding and conv kernel

This commit is contained in:
2024-03-18 19:53:40 +01:00
parent e6d3757312
commit 6cf604423a
6 changed files with 14 additions and 209 deletions

View File

@@ -3,25 +3,6 @@
namespace CUDANet::Kernels { namespace CUDANet::Kernels {
/**
* @brief Kernel that pads the input matrix with zeros
*
* Each thread produces one element of the padded output: 0.0f when the
* element lies in the border of width p, otherwise the matching input
* element. Threads whose global index is not smaller than
* (w + 2 * p) * (h + 2 * p) * n return without writing, so the launch
* only needs to provide at least that many threads along x.
*
* Both matrices are flat vectors stored channel by channel, each channel
* row by row.
*
* @param d_input Device pointer to the input matrix (as vector), holding w * h * n floats
* @param d_padded Device pointer to the padded matrix (as vector); must be pre-allocated with (w + 2 * p) * (h + 2 * p) * n floats
* @param w Width of the input matrix
* @param h Height of the input matrix
* @param n Number of input channels
* @param p Padding size (number of zero elements added on every side of each channel)
*/
__global__ void padding(
const float* __restrict__ d_input,
float* __restrict__ d_padded,
const unsigned int w,
const unsigned int h,
const unsigned int n,
const unsigned int p
);
/** /**
* @brief Convolution kernel * @brief Convolution kernel
* *

View File

@@ -105,7 +105,6 @@ class Conv2d : public ILayer {
float* d_output; float* d_output;
float* d_weights; float* d_weights;
float* d_biases; float* d_biases;
float* d_padded;
// Kernels // Kernels
Layers::Activation activation; Layers::Activation activation;

View File

@@ -2,83 +2,6 @@
#include "convolution.cuh" #include "convolution.cuh"
/*
Pads a width x height x n_channels matrix to
(width + 2 * padding) x (height + 2 * padding) x n_channels.
The matrix is represented as a pointer to a vector.
For example:
w = 2
h = 3
n = 2
p = 1
Channel 0:
0 1
2 3
4 5
Channel 1:
6 7
8 9
10 11
Is represented as:
0 1 2 3 4 5 6 7 8 9 10 11
Padded result (as a continuous vector):
0.0f, 0.0f, 0.0f, 0.0f,
0.0f, 0.0f, 1.0f, 0.0f,
0.0f, 2.0f, 3.0f, 0.0f,
0.0f, 4.0f, 5.0f, 0.0f,
0.0f, 0.0f, 0.0f, 0.0f,
0.0f, 0.0f, 0.0f, 0.0f,
0.0f, 6.0f, 7.0f, 0.0f,
0.0f, 8.0f, 9.0f, 0.0f,
0.0f, 10.0f, 11.0f, 0.0f,
0.0f, 0.0f, 0.0f, 0.0f
Args:
d_input: Pointer to input vector representing matrix
d_padded: Pointer to output vector representing padded matrix (needs to be
pre-allocated)
w: Width of input matrix
h: Height of input matrix
n: Number of channels in input matrix
p: Padding
*/
__global__ void CUDANet::Kernels::padding(
    const float* __restrict__ d_input,
    float* __restrict__ d_padded,
    const unsigned int w,
    const unsigned int h,
    const unsigned int n,
    const unsigned int p
) {
    // Extents of a single zero-padded channel.
    const unsigned int paddedWidth  = w + 2 * p;
    const unsigned int paddedHeight = h + 2 * p;
    const unsigned int paddedArea   = paddedWidth * paddedHeight;

    // One thread per element of the padded output; surplus threads exit.
    const int tid = blockDim.x * blockIdx.x + threadIdx.x;
    if (tid >= paddedArea * n) {
        return;
    }

    // Split the flat output index into (channel, row, column) coordinates
    // of the padded matrix.
    const int channel = tid / paddedArea;
    const int row     = tid % paddedArea / paddedWidth;
    const int col     = tid % paddedWidth;

    // The element maps back to the input only when it lies inside the
    // border of width p on all four sides.
    const bool insideInput =
        col >= p && col < (w + p) && row >= p && row < (h + p);

    if (insideInput) {
        // Shift by the padding to obtain the index into the input vector.
        const int source = channel * w * h + (row - p) * w + (col - p);
        d_padded[tid]    = d_input[source];
    } else {
        d_padded[tid] = 0.0f;
    }
}
__global__ void CUDANet::Kernels::convolution( __global__ void CUDANet::Kernels::convolution(
const float* __restrict__ d_input, const float* __restrict__ d_input,
const float* __restrict__ d_kernel, const float* __restrict__ d_kernel,
@@ -108,12 +31,21 @@ __global__ void CUDANet::Kernels::convolution(
for (int c = 0; c < nChannels; c++) { for (int c = 0; c < nChannels; c++) {
for (int k = 0; k < kernelSize; k++) { for (int k = 0; k < kernelSize; k++) {
for (int l = 0; l < kernelSize; l++) { for (int l = 0; l < kernelSize; l++) {
// if i, j is in the padding region
if (i * stride + k < paddingSize ||
i * stride + k >= (inputSize + paddingSize) ||
j * stride + l < paddingSize ||
j * stride + l >= (inputSize + paddingSize)) {
continue;
}
int kernelIndex = f * kernelSize * kernelSize * nChannels + int kernelIndex = f * kernelSize * kernelSize * nChannels +
c * kernelSize * kernelSize + k * kernelSize + c * kernelSize * kernelSize + k * kernelSize +
l; l;
int inputIndex = c * inputSize * inputSize + int inputIndex = c * inputSize * inputSize +
(i * stride + k) * inputSize + (i * stride + k - paddingSize) * inputSize +
(j * stride + l); (j * stride + l - paddingSize);
sum += d_kernel[kernelIndex] * d_input[inputIndex]; sum += d_kernel[kernelIndex] * d_input[inputIndex];
} }

View File

@@ -23,6 +23,7 @@ Layers::Conv2d::Conv2d(
kernelSize(kernelSize), kernelSize(kernelSize),
stride(stride), stride(stride),
numFilters(numFilters) { numFilters(numFilters) {
switch (padding) { switch (padding) {
case SAME: case SAME:
outputSize = inputSize; outputSize = inputSize;
@@ -64,12 +65,6 @@ Layers::Conv2d::Conv2d(
(void**)&d_biases, sizeof(float) * outputSize * outputSize * numFilters (void**)&d_biases, sizeof(float) * outputSize * outputSize * numFilters
)); ));
d_padded = nullptr;
CUDA_CHECK(cudaMalloc(
(void**)&d_padded, sizeof(float) * (inputSize + 2 * paddingSize) *
(inputSize + 2 * paddingSize) * inputChannels
));
toCuda(); toCuda();
} }
@@ -77,7 +72,6 @@ Layers::Conv2d::~Conv2d() {
cudaFree(d_output); cudaFree(d_output);
cudaFree(d_weights); cudaFree(d_weights);
cudaFree(d_biases); cudaFree(d_biases);
cudaFree(d_padded);
} }
void Layers::Conv2d::initializeWeights() { void Layers::Conv2d::initializeWeights() {
@@ -113,18 +107,10 @@ void Layers::Conv2d::toCuda() {
} }
float* Layers::Conv2d::forward(const float* d_input) { float* Layers::Conv2d::forward(const float* d_input) {
// Pad input
int THREADS_PER_BLOCK = (inputSize + 2 * paddingSize) *
(inputSize + 2 * paddingSize) * inputChannels;
Kernels::padding<<<1, THREADS_PER_BLOCK>>>(
d_input, d_padded, inputSize, inputSize, inputChannels, paddingSize
);
// Convolve // Convolve
THREADS_PER_BLOCK = outputSize * outputSize * numFilters; int THREADS_PER_BLOCK = outputSize * outputSize * numFilters;
Kernels::convolution<<<1, THREADS_PER_BLOCK>>>( Kernels::convolution<<<1, THREADS_PER_BLOCK>>>(
d_padded, d_weights, d_output, inputSize + 2 * paddingSize, inputChannels, paddingSize, d_input, d_weights, d_output, inputSize, inputChannels, paddingSize,
kernelSize, stride, numFilters, outputSize kernelSize, stride, numFilters, outputSize
); );

View File

@@ -9,7 +9,6 @@ add_executable(test_main
layers/test_dense.cu layers/test_dense.cu
layers/test_input.cu layers/test_input.cu
kernels/test_activation_functions.cu kernels/test_activation_functions.cu
kernels/test_padding.cu
kernels/test_matmul.cu kernels/test_matmul.cu
) )

View File

@@ -1,92 +0,0 @@
#include <cuda_runtime_api.h>
#include <gtest/gtest.h>
#include <iostream>
#include "convolution.cuh"
// Verifies that the padding kernel surrounds every channel of a small
// 2 x 3 x 2 matrix with a one-element border of zeros.
TEST(PaddingTest, SimplePaddingTest) {
    // Owns the device allocations so they are released on every exit path,
    // including the early returns taken when an ASSERT_* fails.
    // cudaFree on a null pointer performs no operation.
    struct DeviceBuffers {
        float* input  = nullptr;
        float* padded = nullptr;

        ~DeviceBuffers() {
            cudaFree(input);
            cudaFree(padded);
        }
    } buffers;

    cudaError_t cudaStatus;

    int w = 2;
    int h = 3;
    int n = 2;
    int p = 1;

    int inputSize  = w * h * n;
    int paddedSize = (w + 2 * p) * (h + 2 * p) * n;

    // Allocation and copy failures are fatal: continuing would launch the
    // kernel on invalid pointers and bury the real cause.
    cudaStatus = cudaMalloc((void**)&buffers.input, sizeof(float) * inputSize);
    ASSERT_EQ(cudaStatus, cudaSuccess);

    cudaStatus =
        cudaMalloc((void**)&buffers.padded, sizeof(float) * paddedSize);
    ASSERT_EQ(cudaStatus, cudaSuccess);

    /*
    Matrix channel 0:
    0 1
    2 3
    4 5
    Matrix channel 1:
    6 7
    8 9
    10 11
    Represented as a vector:
    0 1 2 3 4 5 6 7 8 9 10 11
    */
    std::vector<float> input = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f,  5.0f,
                                6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f};
    ASSERT_EQ(input.size(), static_cast<size_t>(inputSize));

    cudaStatus = cudaMemcpy(
        buffers.input, input.data(), sizeof(float) * inputSize,
        cudaMemcpyHostToDevice
    );
    ASSERT_EQ(cudaStatus, cudaSuccess);

    // Ceil-division: just enough blocks to cover every padded element.
    int THREADS_PER_BLOCK = 64;
    int BLOCKS = (paddedSize + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;

    CUDANet::Kernels::padding<<<BLOCKS, THREADS_PER_BLOCK>>>(
        buffers.input, buffers.padded, w, h, n, p
    );

    // Launch-configuration errors are only reported through
    // cudaGetLastError(); execution errors surface at the synchronize.
    cudaStatus = cudaGetLastError();
    ASSERT_EQ(cudaStatus, cudaSuccess);

    cudaStatus = cudaDeviceSynchronize();
    ASSERT_EQ(cudaStatus, cudaSuccess);

    // clang-format off
    std::vector<float> expectedOutput = {
        // channel 0
        0.0f, 0.0f, 0.0f, 0.0f,
        0.0f, 0.0f, 1.0f, 0.0f,
        0.0f, 2.0f, 3.0f, 0.0f,
        0.0f, 4.0f, 5.0f, 0.0f,
        0.0f, 0.0f, 0.0f, 0.0f,
        // channel 1
        0.0f, 0.0f, 0.0f, 0.0f,
        0.0f, 6.0f, 7.0f, 0.0f,
        0.0f, 8.0f, 9.0f, 0.0f,
        0.0f, 10.0f, 11.0f, 0.0f,
        0.0f, 0.0f, 0.0f, 0.0f
    };
    // clang-format on

    // Guard the element-wise comparison against an out-of-range read.
    ASSERT_EQ(expectedOutput.size(), static_cast<size_t>(paddedSize));

    std::vector<float> output(paddedSize);
    cudaStatus = cudaMemcpy(
        output.data(), buffers.padded, sizeof(float) * paddedSize,
        cudaMemcpyDeviceToHost
    );
    ASSERT_EQ(cudaStatus, cudaSuccess);

    for (int i = 0; i < paddedSize; i++) {
        EXPECT_NEAR(expectedOutput[i], output[i], 1e-5)
            << "mismatch at padded index " << i;
    }
}