Add padding to max pooling

This commit is contained in:
2024-05-26 19:03:10 +02:00
parent 4a67b708f0
commit 94a16b4352
7 changed files with 78 additions and 19 deletions

View File

@@ -13,7 +13,8 @@ __global__ void max_pooling(
const dim2d outputSize, const dim2d outputSize,
const int nChannels, const int nChannels,
const dim2d poolingSize, const dim2d poolingSize,
const dim2d stride const dim2d stride,
const dim2d padding
); );
__global__ void avg_pooling( __global__ void avg_pooling(

View File

@@ -13,6 +13,7 @@ class MaxPooling2d : public SequentialLayer, public TwoDLayer {
int nChannels, int nChannels,
dim2d poolingSize, dim2d poolingSize,
dim2d stride, dim2d stride,
dim2d padding,
ActivationType activationType ActivationType activationType
); );
~MaxPooling2d(); ~MaxPooling2d();
@@ -40,6 +41,7 @@ class MaxPooling2d : public SequentialLayer, public TwoDLayer {
int nChannels; int nChannels;
dim2d poolingSize; dim2d poolingSize;
dim2d stride; dim2d stride;
dim2d padding;
dim2d outputSize; dim2d outputSize;

View File

@@ -11,7 +11,8 @@ __global__ void Kernels::max_pooling(
const dim2d outputSize, const dim2d outputSize,
const int nChannels, const int nChannels,
const dim2d poolingSize, const dim2d poolingSize,
const dim2d stride const dim2d stride,
const dim2d padding
) { ) {
int j = blockDim.x * blockIdx.x + threadIdx.x; int j = blockDim.x * blockIdx.x + threadIdx.x;
int i = blockDim.y * blockIdx.y + threadIdx.y; int i = blockDim.y * blockIdx.y + threadIdx.y;
@@ -25,12 +26,16 @@ __global__ void Kernels::max_pooling(
for (int k = 0; k < poolingSize.first; k++) { for (int k = 0; k < poolingSize.first; k++) {
for (int l = 0; l < poolingSize.second; l++) { for (int l = 0; l < poolingSize.second; l++) {
int inputIndex = c * inputSize.first * inputSize.second + int inputRow = i * stride.first + k - padding.first;
(i * stride.first + k) * inputSize.second + int inputCol = j * stride.second + l - padding.second;
(j * stride.second + l);
if (d_input[inputIndex] > max) { if (inputRow >= 0 && inputRow < inputSize.first && inputCol >= 0 &&
max = d_input[inputIndex]; inputCol < inputSize.second) {
int inputIndex = c * inputSize.first * inputSize.second +
inputRow * inputSize.second + inputCol;
if (d_input[inputIndex] > max) {
max = d_input[inputIndex];
}
} }
} }
} }
@@ -62,12 +67,11 @@ __global__ void Kernels::avg_pooling(
for (int k = 0; k < poolingSize.first; k++) { for (int k = 0; k < poolingSize.first; k++) {
for (int l = 0; l < poolingSize.second; l++) { for (int l = 0; l < poolingSize.second; l++) {
int inputRow = i * stride.first + k - padding.first; int inputRow = i * stride.first + k - padding.first;
int inputCol = j * stride.second + l - padding.second; int inputCol = j * stride.second + l - padding.second;
if (inputRow >= 0 && inputRow < inputSize.first && if (inputRow >= 0 && inputRow < inputSize.first && inputCol >= 0 &&
inputCol >= 0 && inputCol < inputSize.second) { inputCol < inputSize.second) {
int inputIndex = c * inputSize.first * inputSize.second + int inputIndex = c * inputSize.first * inputSize.second +
inputRow * inputSize.second + inputCol; inputRow * inputSize.second + inputCol;
sum += d_input[inputIndex]; sum += d_input[inputIndex];

View File

@@ -9,23 +9,31 @@ MaxPooling2d::MaxPooling2d(
int nChannels, int nChannels,
dim2d poolingSize, dim2d poolingSize,
dim2d stride, dim2d stride,
dim2d padding,
ActivationType activationType ActivationType activationType
) )
: inputSize(inputSize), : inputSize(inputSize),
nChannels(nChannels), nChannels(nChannels),
poolingSize(poolingSize), poolingSize(poolingSize),
stride(stride) { stride(stride),
padding(padding) {
outputSize = { outputSize = {
(inputSize.first - poolingSize.first) / stride.first + 1, (inputSize.first + 2 * padding.first - poolingSize.first) /
(inputSize.second - poolingSize.second) / stride.second + 1 stride.first +
1,
(inputSize.second + 2 * padding.second - poolingSize.second) /
stride.second +
1
}; };
activation = activation = new Activation(
new Activation(activationType, outputSize.first * outputSize.second * nChannels); activationType, outputSize.first * outputSize.second * nChannels
);
d_output = nullptr; d_output = nullptr;
CUDA_CHECK(cudaMalloc( CUDA_CHECK(cudaMalloc(
(void**)&d_output, sizeof(float) * outputSize.first * outputSize.second * nChannels (void**)&d_output,
sizeof(float) * outputSize.first * outputSize.second * nChannels
)); ));
} }
@@ -43,7 +51,8 @@ float* MaxPooling2d::forward(const float* d_input) {
); );
Kernels::max_pooling<<<grid, block>>>( Kernels::max_pooling<<<grid, block>>>(
d_input, d_output, inputSize, outputSize, nChannels, poolingSize, stride d_input, d_output, inputSize, outputSize, nChannels, poolingSize,
stride, padding
); );
CUDA_CHECK(cudaGetLastError()); CUDA_CHECK(cudaGetLastError());

View File

@@ -11,6 +11,7 @@ class MaxPoolingLayerTest : public ::testing::Test {
int nChannels; int nChannels;
dim2d poolingSize; dim2d poolingSize;
dim2d stride; dim2d stride;
dim2d padding;
std::vector<float> input; std::vector<float> input;
std::vector<float> expected; std::vector<float> expected;
@@ -35,7 +36,7 @@ class MaxPoolingLayerTest : public ::testing::Test {
cudaError_t cudaStatus; cudaError_t cudaStatus;
maxPoolingLayer = new CUDANet::Layers::MaxPooling2d( maxPoolingLayer = new CUDANet::Layers::MaxPooling2d(
inputSize, nChannels, poolingSize, stride, inputSize, nChannels, poolingSize, stride, padding,
CUDANet::Layers::ActivationType::NONE CUDANet::Layers::ActivationType::NONE
); );
@@ -71,6 +72,7 @@ TEST_F(MaxPoolingLayerTest, MaxPoolForwardTest) {
nChannels = 2; nChannels = 2;
poolingSize = {2, 2}; poolingSize = {2, 2};
stride = {2, 2}; stride = {2, 2};
padding = {0, 0};
input = { input = {
// clang-format off // clang-format off
@@ -97,6 +99,7 @@ TEST_F(MaxPoolingLayerTest, MaxPoolForwardNonSquareInputTest) {
nChannels = 2; nChannels = 2;
poolingSize = {2, 2}; poolingSize = {2, 2};
stride = {2, 2}; stride = {2, 2};
padding = {0, 0};
input = {// Channel 0 input = {// Channel 0
0.573f, 0.619f, 0.732f, 0.055f, 0.123f, 0.234f, 0.243f, 0.316f, 0.573f, 0.619f, 0.732f, 0.055f, 0.123f, 0.234f, 0.243f, 0.316f,
@@ -118,6 +121,7 @@ TEST_F(MaxPoolingLayerTest, MaxPoolForwardNonSquarePoolSizeTest) {
nChannels = 2; nChannels = 2;
poolingSize = {2, 3}; // Non-square pooling size poolingSize = {2, 3}; // Non-square pooling size
stride = {2, 2}; stride = {2, 2};
padding = {0, 0};
input = { input = {
// clang-format off // clang-format off
@@ -145,6 +149,7 @@ TEST_F(MaxPoolingLayerTest, MaxPoolForwardNonSquareStrideTest) {
nChannels = 2; nChannels = 2;
poolingSize = {2, 2}; poolingSize = {2, 2};
stride = {1, 2}; // Non-square stride stride = {1, 2}; // Non-square stride
padding = {0, 0};
input = { input = {
// clang-format off // clang-format off
@@ -165,4 +170,32 @@ TEST_F(MaxPoolingLayerTest, MaxPoolForwardNonSquareStrideTest) {
runTest(); runTest();
}
TEST_F(MaxPoolingLayerTest, MaxPoolForwardNonSquarePaddingTest) {
inputSize = {4, 4};
nChannels = 2;
poolingSize = {2, 2};
    stride = {2, 2};
    padding = {0, 1};  // Non-square (asymmetric) padding
input = {
// clang-format off
// Channel 0
0.573f, 0.619f, 0.732f, 0.055f,
0.243f, 0.316f, 0.573f, 0.619f,
0.712f, 0.055f, 0.243f, 0.316f,
0.573f, 0.619f, 0.742f, 0.055f,
// Channel 1
0.473f, 0.919f, 0.107f, 0.073f,
0.073f, 0.362f, 0.973f, 0.059f,
0.473f, 0.455f, 0.283f, 0.416f,
0.532f, 0.819f, 0.732f, 0.850f
// clang-format on
};
expected = {0.573f, 0.732f, 0.619f, 0.712f, 0.742f, 0.316f, 0.473f, 0.973f, 0.073f, 0.532f, 0.819f, 0.85f};
runTest();
} }

View File

@@ -45,7 +45,7 @@ class ModelTest : public ::testing::Test {
CUDANet::Layers::MaxPooling2d *maxpool2d = CUDANet::Layers::MaxPooling2d *maxpool2d =
new CUDANet::Layers::MaxPooling2d( new CUDANet::Layers::MaxPooling2d(
poolingInput, numFilters, poolingSize, poolingInput, numFilters, poolingSize,
poolingStride, CUDANet::Layers::ActivationType::RELU poolingStride, {0, 0}, CUDANet::Layers::ActivationType::RELU
); );
model->addLayer("maxpool1", maxpool2d); model->addLayer("maxpool1", maxpool2d);

View File

@@ -62,6 +62,14 @@ def gen_max_pool_non_square_stride_test_result():
print_cpp_vector(output) print_cpp_vector(output)
def gen_max_pool_non_square_padding_test_result():
input = _get_pool_input()
output = torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=(0, 1))(input)
output = torch.flatten(output)
print_cpp_vector(output)
def gen_avg_pool_test_result(): def gen_avg_pool_test_result():
@@ -123,6 +131,8 @@ if __name__ == "__main__":
gen_max_non_square_pool_test_result() gen_max_non_square_pool_test_result()
print("Max pool non square stride test:") print("Max pool non square stride test:")
gen_max_pool_non_square_stride_test_result() gen_max_pool_non_square_stride_test_result()
print("Max pool non square padding test:")
gen_max_pool_non_square_padding_test_result()
print("--------------") print("--------------")