mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-06 01:34:22 +00:00
Combine padding and conv kernel
This commit is contained in:
@@ -3,25 +3,6 @@
|
|||||||
|
|
||||||
namespace CUDANet::Kernels {
|
namespace CUDANet::Kernels {
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Kernel that pads the input matrix with zeros
|
|
||||||
*
|
|
||||||
* @param d_input Device pointer to the input matrix (as vector)
|
|
||||||
* @param d_padded Device pointer to the padded matrix (as vector)
|
|
||||||
* @param w Width of the input matrix
|
|
||||||
* @param h Height of the input matrix
|
|
||||||
* @param n Number of input channels
|
|
||||||
* @param p Padding size
|
|
||||||
*/
|
|
||||||
__global__ void padding(
|
|
||||||
const float* __restrict__ d_input,
|
|
||||||
float* __restrict__ d_padded,
|
|
||||||
const unsigned int w,
|
|
||||||
const unsigned int h,
|
|
||||||
const unsigned int n,
|
|
||||||
const unsigned int p
|
|
||||||
);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Convolution kernel
|
* @brief Convolution kernel
|
||||||
*
|
*
|
||||||
|
|||||||
@@ -105,7 +105,6 @@ class Conv2d : public ILayer {
|
|||||||
float* d_output;
|
float* d_output;
|
||||||
float* d_weights;
|
float* d_weights;
|
||||||
float* d_biases;
|
float* d_biases;
|
||||||
float* d_padded;
|
|
||||||
|
|
||||||
// Kernels
|
// Kernels
|
||||||
Layers::Activation activation;
|
Layers::Activation activation;
|
||||||
|
|||||||
@@ -2,83 +2,6 @@
|
|||||||
|
|
||||||
#include "convolution.cuh"
|
#include "convolution.cuh"
|
||||||
|
|
||||||
/*
|
|
||||||
Pads matrix width x height x n_channels to width + 2 * padding x height + 2 *
|
|
||||||
padding x n_channels. The matrix is represented as a pointer to a vector.
|
|
||||||
|
|
||||||
For example:
|
|
||||||
|
|
||||||
w = 2
|
|
||||||
h = 3
|
|
||||||
n = 2
|
|
||||||
p = 1
|
|
||||||
|
|
||||||
Channel 0:
|
|
||||||
0 1
|
|
||||||
2 3
|
|
||||||
4 5
|
|
||||||
Channel 1:
|
|
||||||
6 7
|
|
||||||
8 9
|
|
||||||
10 11
|
|
||||||
|
|
||||||
Is represented as:
|
|
||||||
|
|
||||||
0 1 2 3 4 5 6 7 8 9 10 11
|
|
||||||
|
|
||||||
Padded result (as a continuous vector):
|
|
||||||
|
|
||||||
0.0f, 0.0f, 0.0f, 0.0f,
|
|
||||||
0.0f, 0.0f, 1.0f, 0.0f,
|
|
||||||
0.0f, 2.0f, 3.0f, 0.0f,
|
|
||||||
0.0f, 4.0f, 5.0f, 0.0f,
|
|
||||||
0.0f, 0.0f, 0.0f, 0.0f,
|
|
||||||
0.0f, 0.0f, 0.0f, 0.0f,
|
|
||||||
0.0f, 6.0f, 7.0f, 0.0f,
|
|
||||||
0.0f, 8.0f, 9.0f, 0.0f,
|
|
||||||
0.0f, 10.0f, 11.0f, 0.0f,
|
|
||||||
0.0f, 0.0f, 0.0f, 0.0f
|
|
||||||
|
|
||||||
Args:
|
|
||||||
d_input: Pointer to input vector representing matrix
|
|
||||||
d_padded: Pointer to output vector representing padded matrix (needs to be
|
|
||||||
pre-allocated)
|
|
||||||
w: Width of input matrix
|
|
||||||
h: Height of input matrix
|
|
||||||
n: Number of channels in input matrix
|
|
||||||
p: Padding
|
|
||||||
*/
|
|
||||||
__global__ void CUDANet::Kernels::padding(
|
|
||||||
const float* __restrict__ d_input,
|
|
||||||
float* __restrict__ d_padded,
|
|
||||||
const unsigned int w,
|
|
||||||
const unsigned int h,
|
|
||||||
const unsigned int n,
|
|
||||||
const unsigned int p
|
|
||||||
) {
|
|
||||||
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
|
||||||
|
|
||||||
if (tid >= (w + 2 * p) * (h + 2 * p) * n) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
int idx = tid;
|
|
||||||
|
|
||||||
// unravel index into padded matrix
|
|
||||||
int i_n = idx / ((w + 2 * p) * (h + 2 * p));
|
|
||||||
int i_h = idx % ((w + 2 * p) * (h + 2 * p)) / (w + 2 * p);
|
|
||||||
int i_w = idx % (w + 2 * p);
|
|
||||||
|
|
||||||
// if i is in the padding region
|
|
||||||
if (i_w < p || i_w >= (w + p) || i_h < p || i_h >= (h + p)) {
|
|
||||||
d_padded[tid] = 0.0f;
|
|
||||||
} else {
|
|
||||||
// Get index into input vector
|
|
||||||
int i_orig = i_n * w * h + (i_h - p) * w + (i_w - p);
|
|
||||||
d_padded[tid] = d_input[i_orig];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
__global__ void CUDANet::Kernels::convolution(
|
__global__ void CUDANet::Kernels::convolution(
|
||||||
const float* __restrict__ d_input,
|
const float* __restrict__ d_input,
|
||||||
const float* __restrict__ d_kernel,
|
const float* __restrict__ d_kernel,
|
||||||
@@ -108,12 +31,21 @@ __global__ void CUDANet::Kernels::convolution(
|
|||||||
for (int c = 0; c < nChannels; c++) {
|
for (int c = 0; c < nChannels; c++) {
|
||||||
for (int k = 0; k < kernelSize; k++) {
|
for (int k = 0; k < kernelSize; k++) {
|
||||||
for (int l = 0; l < kernelSize; l++) {
|
for (int l = 0; l < kernelSize; l++) {
|
||||||
|
|
||||||
|
// skip kernel taps whose input position (i * stride + k, j * stride + l) falls in the padding region
|
||||||
|
if (i * stride + k < paddingSize ||
|
||||||
|
i * stride + k >= (inputSize + paddingSize) ||
|
||||||
|
j * stride + l < paddingSize ||
|
||||||
|
j * stride + l >= (inputSize + paddingSize)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
int kernelIndex = f * kernelSize * kernelSize * nChannels +
|
int kernelIndex = f * kernelSize * kernelSize * nChannels +
|
||||||
c * kernelSize * kernelSize + k * kernelSize +
|
c * kernelSize * kernelSize + k * kernelSize +
|
||||||
l;
|
l;
|
||||||
int inputIndex = c * inputSize * inputSize +
|
int inputIndex = c * inputSize * inputSize +
|
||||||
(i * stride + k) * inputSize +
|
(i * stride + k - paddingSize) * inputSize +
|
||||||
(j * stride + l);
|
(j * stride + l - paddingSize);
|
||||||
|
|
||||||
sum += d_kernel[kernelIndex] * d_input[inputIndex];
|
sum += d_kernel[kernelIndex] * d_input[inputIndex];
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ Layers::Conv2d::Conv2d(
|
|||||||
kernelSize(kernelSize),
|
kernelSize(kernelSize),
|
||||||
stride(stride),
|
stride(stride),
|
||||||
numFilters(numFilters) {
|
numFilters(numFilters) {
|
||||||
|
|
||||||
switch (padding) {
|
switch (padding) {
|
||||||
case SAME:
|
case SAME:
|
||||||
outputSize = inputSize;
|
outputSize = inputSize;
|
||||||
@@ -64,12 +65,6 @@ Layers::Conv2d::Conv2d(
|
|||||||
(void**)&d_biases, sizeof(float) * outputSize * outputSize * numFilters
|
(void**)&d_biases, sizeof(float) * outputSize * outputSize * numFilters
|
||||||
));
|
));
|
||||||
|
|
||||||
d_padded = nullptr;
|
|
||||||
CUDA_CHECK(cudaMalloc(
|
|
||||||
(void**)&d_padded, sizeof(float) * (inputSize + 2 * paddingSize) *
|
|
||||||
(inputSize + 2 * paddingSize) * inputChannels
|
|
||||||
));
|
|
||||||
|
|
||||||
toCuda();
|
toCuda();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -77,7 +72,6 @@ Layers::Conv2d::~Conv2d() {
|
|||||||
cudaFree(d_output);
|
cudaFree(d_output);
|
||||||
cudaFree(d_weights);
|
cudaFree(d_weights);
|
||||||
cudaFree(d_biases);
|
cudaFree(d_biases);
|
||||||
cudaFree(d_padded);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Layers::Conv2d::initializeWeights() {
|
void Layers::Conv2d::initializeWeights() {
|
||||||
@@ -113,18 +107,10 @@ void Layers::Conv2d::toCuda() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
float* Layers::Conv2d::forward(const float* d_input) {
|
float* Layers::Conv2d::forward(const float* d_input) {
|
||||||
// Pad input
|
|
||||||
int THREADS_PER_BLOCK = (inputSize + 2 * paddingSize) *
|
|
||||||
(inputSize + 2 * paddingSize) * inputChannels;
|
|
||||||
|
|
||||||
Kernels::padding<<<1, THREADS_PER_BLOCK>>>(
|
|
||||||
d_input, d_padded, inputSize, inputSize, inputChannels, paddingSize
|
|
||||||
);
|
|
||||||
|
|
||||||
// Convolve
|
// Convolve
|
||||||
THREADS_PER_BLOCK = outputSize * outputSize * numFilters;
|
int THREADS_PER_BLOCK = outputSize * outputSize * numFilters;
|
||||||
Kernels::convolution<<<1, THREADS_PER_BLOCK>>>(
|
Kernels::convolution<<<1, THREADS_PER_BLOCK>>>(
|
||||||
d_padded, d_weights, d_output, inputSize + 2 * paddingSize, inputChannels, paddingSize,
|
d_input, d_weights, d_output, inputSize, inputChannels, paddingSize,
|
||||||
kernelSize, stride, numFilters, outputSize
|
kernelSize, stride, numFilters, outputSize
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ add_executable(test_main
|
|||||||
layers/test_dense.cu
|
layers/test_dense.cu
|
||||||
layers/test_input.cu
|
layers/test_input.cu
|
||||||
kernels/test_activation_functions.cu
|
kernels/test_activation_functions.cu
|
||||||
kernels/test_padding.cu
|
|
||||||
kernels/test_matmul.cu
|
kernels/test_matmul.cu
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,92 +0,0 @@
|
|||||||
#include <cuda_runtime_api.h>
|
|
||||||
#include <gtest/gtest.h>
|
|
||||||
|
|
||||||
#include <iostream>
|
|
||||||
|
|
||||||
#include "convolution.cuh"
|
|
||||||
|
|
||||||
TEST(PaddingTest, SimplePaddingTest) {
|
|
||||||
cudaError_t cudaStatus;
|
|
||||||
|
|
||||||
int w = 2;
|
|
||||||
int h = 3;
|
|
||||||
int n = 2;
|
|
||||||
int p = 1;
|
|
||||||
|
|
||||||
float* d_input;
|
|
||||||
float* d_padded;
|
|
||||||
|
|
||||||
int inputSize = w * h * n;
|
|
||||||
int paddedSize = (w + 2 * p) * (h + 2 * p) * n;
|
|
||||||
|
|
||||||
cudaStatus = cudaMalloc((void**)&d_input, sizeof(float) * inputSize);
|
|
||||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
|
||||||
|
|
||||||
cudaStatus = cudaMalloc((void**)&d_padded, sizeof(float) * paddedSize);
|
|
||||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
|
||||||
|
|
||||||
/*
|
|
||||||
Matrix channel 0:
|
|
||||||
0 1
|
|
||||||
2 3
|
|
||||||
4 5
|
|
||||||
Matrix channel 1:
|
|
||||||
6 7
|
|
||||||
8 9
|
|
||||||
10 11
|
|
||||||
|
|
||||||
Represented as a vector:
|
|
||||||
|
|
||||||
0 1 2 3 4 5 6 7 8 9 10 11
|
|
||||||
*/
|
|
||||||
|
|
||||||
std::vector<float> input = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
|
|
||||||
6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f};
|
|
||||||
|
|
||||||
cudaStatus = cudaMemcpy(
|
|
||||||
d_input, input.data(), sizeof(float) * inputSize, cudaMemcpyHostToDevice
|
|
||||||
);
|
|
||||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
|
||||||
|
|
||||||
int THREADS_PER_BLOCK = 64;
|
|
||||||
int BLOCKS = paddedSize / THREADS_PER_BLOCK + 1;
|
|
||||||
|
|
||||||
CUDANet::Kernels::padding<<<BLOCKS, THREADS_PER_BLOCK>>>(
|
|
||||||
d_input, d_padded, w, h, n, p
|
|
||||||
);
|
|
||||||
cudaStatus = cudaDeviceSynchronize();
|
|
||||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
std::vector<float> expectedOutput = {
|
|
||||||
// channel 0
|
|
||||||
0.0f, 0.0f, 0.0f, 0.0f,
|
|
||||||
0.0f, 0.0f, 1.0f, 0.0f,
|
|
||||||
0.0f, 2.0f, 3.0f, 0.0f,
|
|
||||||
0.0f, 4.0f, 5.0f, 0.0f,
|
|
||||||
0.0f, 0.0f, 0.0f, 0.0f,
|
|
||||||
// channel 1
|
|
||||||
0.0f, 0.0f, 0.0f, 0.0f,
|
|
||||||
0.0f, 6.0f, 7.0f, 0.0f,
|
|
||||||
0.0f, 8.0f, 9.0f, 0.0f,
|
|
||||||
0.0f, 10.0f, 11.0f, 0.0f,
|
|
||||||
0.0f, 0.0f, 0.0f, 0.0f
|
|
||||||
};
|
|
||||||
// clang-format on
|
|
||||||
|
|
||||||
std::vector<float> output(paddedSize);
|
|
||||||
|
|
||||||
cudaStatus = cudaMemcpy(
|
|
||||||
output.data(), d_padded, sizeof(float) * paddedSize,
|
|
||||||
cudaMemcpyDeviceToHost
|
|
||||||
);
|
|
||||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
|
||||||
|
|
||||||
for (int i = 0; i < paddedSize; i++) {
|
|
||||||
EXPECT_NEAR(expectedOutput[i], output[i], 1e-5);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
cudaFree(d_input);
|
|
||||||
cudaFree(d_padded);
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user