Add padding to avg pooling

This commit is contained in:
2024-05-26 18:54:12 +02:00
parent d1fc45d9e0
commit 4a67b708f0
6 changed files with 60 additions and 11 deletions

View File

@@ -11,6 +11,7 @@ class AvgPoolingLayerTest : public ::testing::Test {
int nChannels;
dim2d poolingSize;
dim2d stride;
dim2d padding;
std::vector<float> input;
std::vector<float> expected;
@@ -34,7 +35,7 @@ class AvgPoolingLayerTest : public ::testing::Test {
cudaError_t cudaStatus;
avgPoolingLayer = new CUDANet::Layers::AvgPooling2d(
inputSize, nChannels, poolingSize, stride,
inputSize, nChannels, poolingSize, stride, padding,
CUDANet::Layers::ActivationType::NONE
);
@@ -75,6 +76,7 @@ TEST_F(AvgPoolingLayerTest, AvgPoolForwardTest) {
nChannels = 2;
poolingSize = {2, 2};
stride = {2, 2};
padding = {0, 0};
input = {
// clang-format off
@@ -102,6 +104,7 @@ TEST_F(AvgPoolingLayerTest, AvgPoolForwardNonSquareInputTest) {
nChannels = 2;
poolingSize = {2, 2};
stride = {2, 2};
padding = {0, 0};
input = {// Channel 0
0.573f, 0.619f, 0.732f, 0.055f, 0.123f, 0.234f, 0.243f, 0.316f,
@@ -124,6 +127,7 @@ TEST_F(AvgPoolingLayerTest, AvgPoolForwardNonSquarePoolingTest) {
nChannels = 2;
poolingSize = {2, 3}; // Non-square pooling
stride = {2, 2};
padding = {0, 0};
input = {// Channel 0
0.573f, 0.619f, 0.732f, 0.055f, 0.243f, 0.316f, 0.573f, 0.619f,
@@ -143,6 +147,7 @@ TEST_F(AvgPoolingLayerTest, AvgPoolForwardNonSquareStrideTest) {
nChannels = 2;
poolingSize = {2, 2};
stride = {1, 2}; // Non-square stride
padding = {0, 0};
input = {// Channel 0
0.573f, 0.619f, 0.732f, 0.055f, 0.243f, 0.316f, 0.573f, 0.619f,
@@ -155,5 +160,26 @@ TEST_F(AvgPoolingLayerTest, AvgPoolForwardNonSquareStrideTest) {
expected = {0.43775f, 0.49475f, 0.3315f, 0.43775f, 0.48975f, 0.339f,
0.45675f, 0.303f, 0.34075f, 0.43275f, 0.56975f, 0.57025f};
runTest();
}
TEST_F(AvgPoolingLayerTest, AvgPoolForwardNonSquarePaddingTest) {
inputSize = {4, 4};
nChannels = 2;
poolingSize = {2, 2};
stride = {2, 2};
padding = {1, 0}; // Non-square padding
input = {// Channel 0
0.573f, 0.619f, 0.732f, 0.055f, 0.243f, 0.316f, 0.573f, 0.619f,
0.712f, 0.055f, 0.243f, 0.316f, 0.573f, 0.619f, 0.742f, 0.055f,
// Channel 1
0.473f, 0.919f, 0.107f, 0.073f, 0.073f, 0.362f, 0.973f, 0.059f,
0.473f, 0.455f, 0.283f, 0.416f, 0.532f, 0.819f, 0.732f, 0.850f
};
expected = {0.298f, 0.19675f, 0.3315f, 0.43775f, 0.298f, 0.19925f,
0.348f, 0.045f, 0.34075f, 0.43275f, 0.33775f, 0.3955f};
runTest();
}