From 4a67b708f0dc84ca3fc15b78c7a512e712c499fc Mon Sep 17 00:00:00 2001 From: LordMathis Date: Sun, 26 May 2024 18:54:12 +0200 Subject: [PATCH] Add padding to avg pooling --- include/kernels/pooling.cuh | 3 ++- include/layers/avg_pooling.cuh | 2 ++ src/kernels/pooling.cu | 16 +++++++++++----- src/layers/avg_pooling.cu | 11 +++++++---- test/layers/test_avg_pooling.cu | 28 +++++++++++++++++++++++++++- tools/pooling_test.py | 11 +++++++++++ 6 files changed, 60 insertions(+), 11 deletions(-) diff --git a/include/kernels/pooling.cuh b/include/kernels/pooling.cuh index 4c3c34c..35d6dda 100644 --- a/include/kernels/pooling.cuh +++ b/include/kernels/pooling.cuh @@ -23,7 +23,8 @@ __global__ void avg_pooling( const dim2d outputSize, const int nChannels, const dim2d poolingSize, - const dim2d stride + const dim2d stride, + const dim2d padding ); } // namespace CUDANet::Kernels diff --git a/include/layers/avg_pooling.cuh b/include/layers/avg_pooling.cuh index b9f41ab..4fe68c4 100644 --- a/include/layers/avg_pooling.cuh +++ b/include/layers/avg_pooling.cuh @@ -13,6 +13,7 @@ class AvgPooling2d : public SequentialLayer, public TwoDLayer { int nChannels, dim2d poolingSize, dim2d stride, + dim2d padding, ActivationType activationType ); ~AvgPooling2d(); @@ -40,6 +41,7 @@ class AvgPooling2d : public SequentialLayer, public TwoDLayer { int nChannels; dim2d poolingSize; dim2d stride; + dim2d padding; dim2d outputSize; diff --git a/src/kernels/pooling.cu b/src/kernels/pooling.cu index 0429ddb..ffc520f 100644 --- a/src/kernels/pooling.cu +++ b/src/kernels/pooling.cu @@ -47,7 +47,8 @@ __global__ void Kernels::avg_pooling( const dim2d outputSize, const int nChannels, const dim2d poolingSize, - const dim2d stride + const dim2d stride, + const dim2d padding ) { int j = blockDim.x * blockIdx.x + threadIdx.x; int i = blockDim.y * blockIdx.y + threadIdx.y; @@ -61,11 +62,16 @@ __global__ void Kernels::avg_pooling( for (int k = 0; k < poolingSize.first; k++) { for (int l = 0; l < poolingSize.second; l++) { - int inputIndex = c * inputSize.first * inputSize.second + - (i * stride.first + k) * inputSize.second + - (j * stride.second + l); - sum += d_input[inputIndex]; + int inputRow = i * stride.first + k - padding.first; + int inputCol = j * stride.second + l - padding.second; + + if (inputRow >= 0 && inputRow < inputSize.first && + inputCol >= 0 && inputCol < inputSize.second) { + int inputIndex = c * inputSize.first * inputSize.second + + inputRow * inputSize.second + inputCol; + sum += d_input[inputIndex]; + } } } diff --git a/src/layers/avg_pooling.cu b/src/layers/avg_pooling.cu index be35ebd..b6c19bd 100644 --- a/src/layers/avg_pooling.cu +++ b/src/layers/avg_pooling.cu @@ -9,15 +9,17 @@ AvgPooling2d::AvgPooling2d( int nChannels, dim2d poolingSize, dim2d stride, + dim2d padding, ActivationType activationType ) : inputSize(inputSize), nChannels(nChannels), poolingSize(poolingSize), - stride(stride) { + stride(stride), + padding(padding) { outputSize = { - (inputSize.first - poolingSize.first) / stride.first + 1, - (inputSize.second - poolingSize.second) / stride.second + 1 + (inputSize.first + 2 * padding.first - poolingSize.first) / stride.first + 1, + (inputSize.second + 2 * padding.second - poolingSize.second) / stride.second + 1 }; activation = new Activation( @@ -45,7 +47,8 @@ float* AvgPooling2d::forward(const float* d_input) { ); Kernels::avg_pooling<<>>( - d_input, d_output, inputSize, outputSize, nChannels, poolingSize, stride + d_input, d_output, inputSize, outputSize, nChannels, poolingSize, + stride, padding ); CUDA_CHECK(cudaGetLastError()); diff --git a/test/layers/test_avg_pooling.cu b/test/layers/test_avg_pooling.cu index 4ccc69d..0d70d40 100644 --- a/test/layers/test_avg_pooling.cu +++ b/test/layers/test_avg_pooling.cu @@ -11,6 +11,7 @@ class AvgPoolingLayerTest : public ::testing::Test { int nChannels; dim2d poolingSize; dim2d stride; + dim2d padding; std::vector input; std::vector expected; @@ -34,7 +35,7 @@ class AvgPoolingLayerTest : public ::testing::Test { cudaError_t cudaStatus; avgPoolingLayer = new CUDANet::Layers::AvgPooling2d( - inputSize, nChannels, poolingSize, stride, + inputSize, nChannels, poolingSize, stride, padding, CUDANet::Layers::ActivationType::NONE ); @@ -75,6 +76,7 @@ TEST_F(AvgPoolingLayerTest, AvgPoolForwardTest) { nChannels = 2; poolingSize = {2, 2}; stride = {2, 2}; + padding = {0, 0}; input = { // clang-format off @@ -102,6 +104,7 @@ TEST_F(AvgPoolingLayerTest, AvgPoolForwardNonSquareInputTest) { nChannels = 2; poolingSize = {2, 2}; stride = {2, 2}; + padding = {0, 0}; input = {// Channel 0 0.573f, 0.619f, 0.732f, 0.055f, 0.123f, 0.234f, 0.243f, 0.316f, @@ -124,6 +127,7 @@ TEST_F(AvgPoolingLayerTest, AvgPoolForwardNonSquarePoolingTest) { nChannels = 2; poolingSize = {2, 3}; // Non-square pooling stride = {2, 2}; + padding = {0, 0}; input = {// Channel 0 0.573f, 0.619f, 0.732f, 0.055f, 0.243f, 0.316f, 0.573f, 0.619f, @@ -143,6 +147,7 @@ TEST_F(AvgPoolingLayerTest, AvgPoolForwardNonSquareStrideTest) { nChannels = 2; poolingSize = {2, 2}; stride = {1, 2}; // Non-square stride + padding = {0, 0}; input = {// Channel 0 0.573f, 0.619f, 0.732f, 0.055f, 0.243f, 0.316f, 0.573f, 0.619f, @@ -155,5 +160,26 @@ TEST_F(AvgPoolingLayerTest, AvgPoolForwardNonSquareStrideTest) { expected = {0.43775f, 0.49475f, 0.3315f, 0.43775f, 0.48975f, 0.339f, 0.45675f, 0.303f, 0.34075f, 0.43275f, 0.56975f, 0.57025f}; + runTest(); +} + +TEST_F(AvgPoolingLayerTest, AvgPoolForwardNonSquarePaddingTest) { + inputSize = {4, 4}; + nChannels = 2; + poolingSize = {2, 2}; + stride = {2, 2}; + padding = {1, 0}; // Non-square padding + + input = {// Channel 0 + 0.573f, 0.619f, 0.732f, 0.055f, 0.243f, 0.316f, 0.573f, 0.619f, + 0.712f, 0.055f, 0.243f, 0.316f, 0.573f, 0.619f, 0.742f, 0.055f, + // Channel 1 + 0.473f, 0.919f, 0.107f, 0.073f, 0.073f, 0.362f, 0.973f, 0.059f, + 0.473f, 0.455f, 0.283f, 0.416f, 0.532f, 0.819f, 0.732f, 0.850f + }; + + expected = {0.298f, 0.19675f, 0.3315f, 0.43775f, 0.298f, 0.19925f, + 0.348f, 0.045f, 0.34075f, 0.43275f, 0.33775f, 0.3955f}; + runTest(); } \ No newline at end of file diff --git a/tools/pooling_test.py b/tools/pooling_test.py index 8d0caf1..3e48381 100644 --- a/tools/pooling_test.py +++ b/tools/pooling_test.py @@ -103,6 +103,15 @@ def gen_avg_pool_non_square_stride_test_result(): print_cpp_vector(output) +def gen_avg_pool_non_square_padding_test_result(): + + input = _get_pool_input() + + output = torch.nn.AvgPool2d(kernel_size=2, stride=2, padding=(1, 0))(input) + output = torch.flatten(output) + + print_cpp_vector(output) + if __name__ == "__main__": print("Generating test results...") @@ -125,3 +134,5 @@ if __name__ == "__main__": gen_avg_non_square_pool_test_result() print("Avg pool non square stride test:") gen_avg_pool_non_square_stride_test_result() + print("Avg pool non square padding test:") + gen_avg_pool_non_square_padding_test_result()