From 6f4cdf3792b8eea76668dc49009270a583473d7f Mon Sep 17 00:00:00 2001
From: LordMathis
Date: Wed, 20 Mar 2024 21:57:22 +0100
Subject: [PATCH] Implement avg pool test

---
 include/layers/avg_pooling.cuh  | 10 +++-
 include/layers/max_pooling.cuh  |  1 -
 src/kernels/pooling.cu          |  2 +-
 src/layers/avg_pooling.cu       |  3 --
 src/layers/max_pooling.cu       |  3 --
 test/layers/test_avg_pooling.cu | 70 +++++++++++++++++++++++++
 tools/generate_test_results.py  | 90 ++++++++++++++++++++++++---------
 7 files changed, 147 insertions(+), 32 deletions(-)
 create mode 100644 test/layers/test_avg_pooling.cu

diff --git a/include/layers/avg_pooling.cuh b/include/layers/avg_pooling.cuh
index 2a7416c..f992078 100644
--- a/include/layers/avg_pooling.cuh
+++ b/include/layers/avg_pooling.cuh
@@ -19,6 +19,15 @@ class AvgPooling2D : public SequentialLayer {
 
     float* forward(const float* d_input);
 
+    /**
+     * @brief Get the output width (and height) of the layer
+     *
+     * @return int
+     */
+    int getOutputSize() {
+        return outputSize;
+    }
+
   private:
     int inputSize;
     int nChannels;
@@ -26,7 +35,6 @@ class AvgPooling2D : public SequentialLayer {
     int stride;
 
     int outputSize;
-    int gridSize;
 
     float* d_output;
 
diff --git a/include/layers/max_pooling.cuh b/include/layers/max_pooling.cuh
index 05044c9..321412c 100644
--- a/include/layers/max_pooling.cuh
+++ b/include/layers/max_pooling.cuh
@@ -37,7 +37,6 @@ class MaxPooling2D : public SequentialLayer {
     int stride;
 
     int outputSize;
-    int gridSize;
 
     float* d_output;
 
diff --git a/src/kernels/pooling.cu b/src/kernels/pooling.cu
index baf59a7..97094cf 100644
--- a/src/kernels/pooling.cu
+++ b/src/kernels/pooling.cu
@@ -49,7 +49,7 @@ __global__ void Kernels::avg_pooling(
     int i = blockDim.y * blockIdx.y + threadIdx.y;
     int c = blockDim.z * blockIdx.z + threadIdx.z;
 
-    if (i >= inputSize || j >= inputSize || c >= nChannels) {
+    if (i >= outputSize || j >= outputSize || c >= nChannels) {
         return;
     }
 
diff --git a/src/layers/avg_pooling.cu b/src/layers/avg_pooling.cu
index 70d735d..aedcc13 100644
--- a/src/layers/avg_pooling.cu
+++ b/src/layers/avg_pooling.cu
@@ -24,9 +24,6 @@ AvgPooling2D::AvgPooling2D(
     CUDA_CHECK(cudaMalloc(
         (void**)&d_output, sizeof(float) * outputSize * outputSize * nChannels
     ));
-
-    gridSize =
-        (outputSize * outputSize * nChannels + BLOCK_SIZE - 1) / BLOCK_SIZE;
 }
 
 AvgPooling2D::~AvgPooling2D() {
diff --git a/src/layers/max_pooling.cu b/src/layers/max_pooling.cu
index 1aad452..3f1f426 100644
--- a/src/layers/max_pooling.cu
+++ b/src/layers/max_pooling.cu
@@ -25,9 +25,6 @@ MaxPooling2D::MaxPooling2D(
     CUDA_CHECK(cudaMalloc(
         (void**)&d_output, sizeof(float) * outputSize * outputSize * nChannels
     ));
-
-    gridSize = (outputSize * outputSize * nChannels + BLOCK_SIZE - 1) / BLOCK_SIZE;
-
 }
 
diff --git a/test/layers/test_avg_pooling.cu b/test/layers/test_avg_pooling.cu
new file mode 100644
index 0000000..2fcacec
--- /dev/null
+++ b/test/layers/test_avg_pooling.cu
@@ -0,0 +1,70 @@
+#include <cuda_runtime.h>
+#include <gtest/gtest.h>
+
+#include <vector>
+
+#include "avg_pooling.cuh"
+
+TEST(AvgPoolingLayerTest, AvgPoolForwardTest) {
+    int inputSize = 4;
+    int nChannels = 2;
+    int poolingSize = 2;
+    int stride = 2;
+
+    cudaError_t cudaStatus;
+
+    std::vector<float> input = {
+        // clang-format off
+        // Channel 0
+        0.573f, 0.619f, 0.732f, 0.055f,
+        0.243f, 0.316f, 0.573f, 0.619f,
+        0.712f, 0.055f, 0.243f, 0.316f,
+        0.573f, 0.619f, 0.742f, 0.055f,
+        // Channel 1
+        0.473f, 0.919f, 0.107f, 0.073f,
+        0.073f, 0.362f, 0.973f, 0.059f,
+        0.473f, 0.455f, 0.283f, 0.416f,
+        0.532f, 0.819f, 0.732f, 0.850f
+        // clang-format on
+    };
+
+    CUDANet::Layers::AvgPooling2D avgPoolingLayer(
+        inputSize, nChannels, poolingSize, stride,
+        CUDANet::Layers::ActivationType::NONE
+    );
+
+    float *d_input;
+
+    cudaStatus = cudaMalloc(
+        (void **)&d_input, sizeof(float) * inputSize * inputSize * nChannels
+    );
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    cudaStatus = cudaMemcpy(
+        d_input, input.data(),
+        sizeof(float) * inputSize * inputSize * nChannels,
+        cudaMemcpyHostToDevice
+    );
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    float *d_output = avgPoolingLayer.forward(d_input);
+
+    int outputSize = avgPoolingLayer.getOutputSize();
+
+    std::vector<float> output(outputSize * outputSize * nChannels);
+    cudaStatus = cudaMemcpy(
+        output.data(), d_output,
+        sizeof(float) * outputSize * outputSize * nChannels,
+        cudaMemcpyDeviceToHost
+    );
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    std::vector<float> expected = {0.43775f, 0.49475f, 0.48975f, 0.339f, 0.45675f, 0.303f, 0.56975f, 0.57025f};
+
+    for (int i = 0; i < output.size(); ++i) {
+        EXPECT_NEAR(expected[i], output[i], 1e-4);
+    }
+
+    cudaFree(d_input);
+    cudaFree(d_output);
+}
diff --git a/tools/generate_test_results.py b/tools/generate_test_results.py
index bfb71ad..f70fdaa 100644
--- a/tools/generate_test_results.py
+++ b/tools/generate_test_results.py
@@ -1,8 +1,20 @@
 import torch
 
 
-def _conv2d(in_channels, out_channels, kernel_size, stride, padding, inputs, weights):
-    conv2d = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
+def _conv2d(in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding,
+            inputs,
+            weights):
+
+    conv2d = torch.nn.Conv2d(in_channels=in_channels,
+                             out_channels=out_channels,
+                             kernel_size=kernel_size,
+                             stride=stride,
+                             padding=padding,
+                             bias=False)
     conv2d.weight = torch.nn.Parameter(weights)
 
     output = conv2d(inputs)
@@ -11,6 +23,7 @@ def _conv2d(in_channels, out_channels, kernel_size, stride, padding, inputs, wei
     output = torch.flatten(output)
     return output
 
+
 def _print_cpp_vector(vector):
     print("std::vector<float> expected = {", end="")
     for i in range(len(vector)):
@@ -20,6 +33,19 @@
     print("};")
 
 
+def _get_pool_input():
+    return torch.tensor([
+        0.573, 0.619, 0.732, 0.055,
+        0.243, 0.316, 0.573, 0.619,
+        0.712, 0.055, 0.243, 0.316,
+        0.573, 0.619, 0.742, 0.055,
+        0.473, 0.919, 0.107, 0.073,
+        0.073, 0.362, 0.973, 0.059,
+        0.473, 0.455, 0.283, 0.416,
+        0.532, 0.819, 0.732, 0.850
+    ]).reshape(1, 2, 4, 4)
+
+
 def gen_convd_padded_test_result():
 
     in_channels = 3
@@ -68,9 +94,16 @@ def gen_convd_padded_test_result():
         0.011, 0.345, 0.678
     ], dtype=torch.float).reshape(2, 3, 3, 3)
 
-    output = _conv2d(in_channels, out_channels, kernel_size, stride, padding, inputs, weights)
+    output = _conv2d(in_channels,
+                     out_channels,
+                     kernel_size,
+                     stride,
+                     padding,
+                     inputs,
+                     weights)
 
     _print_cpp_vector(output)
 
+
 def gen_convd_strided_test_result():
 
@@ -78,7 +111,7 @@ def gen_convd_strided_test_result():
     kernel_size = 3
     stride = 2
     padding = 3
-    
+
     input = torch.tensor([
         0.946, 0.879, 0.382, 0.542, 0.453,
         0.128, 0.860, 0.778, 0.049, 0.974,
@@ -106,9 +139,16 @@ def gen_convd_strided_test_result():
         0.939, 0.891, 0.006
     ], dtype=torch.float).reshape(2, 2, 3, 3)
 
-    output = _conv2d(in_channels, out_channels, kernel_size, stride, padding, input, weights)
+    output = _conv2d(in_channels,
+                     out_channels,
+                     kernel_size,
+                     stride,
+                     padding,
+                     input,
+                     weights)
 
     _print_cpp_vector(output)
 
+
 def gen_softmax_test_result():
     input = torch.tensor([
         0.573, 0.619, 0.732, 0.055, 0.243
     ], dtype=torch.float)
@@ -117,17 +157,9 @@ def gen_softmax_test_result():
     output = torch.nn.Softmax(dim=0)(input)
     _print_cpp_vector(output)
 
+
 def gen_max_pool_test_result():
-    input = torch.tensor([
-        0.573, 0.619, 0.732, 0.055,
-        0.243, 0.316, 0.573, 0.619,
-        0.712, 0.055, 0.243, 0.316,
-        0.573, 0.619, 0.742, 0.055,
-        0.473, 0.919, 0.107, 0.073,
-        0.073, 0.362, 0.973, 0.059,
-        0.473, 0.455, 0.283, 0.416,
-        0.532, 0.819, 0.732, 0.850
-    ]).reshape(1, 2, 4, 4)
+    input = _get_pool_input()
 
     output = torch.nn.MaxPool2d(kernel_size=2, stride=2)(input)
     output = torch.flatten(output)
@@ -135,13 +167,25 @@ def gen_max_pool_test_result():
     _print_cpp_vector(output)
 
 
+def gen_avg_pool_test_result():
+
+    input = _get_pool_input()
+
+    output = torch.nn.AvgPool2d(kernel_size=2, stride=2)(input)
+    output = torch.flatten(output)
+
+    _print_cpp_vector(output)
+
+
 if __name__ == "__main__":
-    # print("Generating test results...")
-    # print("Padded convolution test:")
-    # gen_convd_padded_test_result()
-    # print("Strided convolution test:")
-    # gen_convd_strided_test_result()
-    # print("Softmax test:")
-    # gen_softmax_test_result()
+    print("Generating test results...")
+    print("Padded convolution test:")
+    gen_convd_padded_test_result()
+    print("Strided convolution test:")
+    gen_convd_strided_test_result()
+    print("Softmax test:")
+    gen_softmax_test_result()
     print("Max pool test:")
-    gen_max_pool_test_result()
\ No newline at end of file
+    gen_max_pool_test_result()
+    print("Avg pool test:")
+    gen_avg_pool_test_result()
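
Note: the expected vector in test_avg_pooling.cu can be re-derived by hand. Each output element is the mean of one non-overlapping 2x2 window, e.g. the top-left window of channel 0 gives (0.573 + 0.619 + 0.243 + 0.316) / 4 = 0.43775. The snippet below is a minimal standalone sketch, equivalent to running gen_avg_pool_test_result() from the patched script; it assumes only that torch is installed.

import torch

# Two 4x4 channels, flattened row-major; same values as the test input.
x = torch.tensor([
    0.573, 0.619, 0.732, 0.055,
    0.243, 0.316, 0.573, 0.619,
    0.712, 0.055, 0.243, 0.316,
    0.573, 0.619, 0.742, 0.055,
    0.473, 0.919, 0.107, 0.073,
    0.073, 0.362, 0.973, 0.059,
    0.473, 0.455, 0.283, 0.416,
    0.532, 0.819, 0.732, 0.850,
]).reshape(1, 2, 4, 4)

# 2x2 average pooling with stride 2 halves each spatial dimension,
# so the output is 1 x 2 x 2 x 2, flattened to 8 values.
out = torch.flatten(torch.nn.AvgPool2d(kernel_size=2, stride=2)(x))
print(out.tolist())
# Matches the expected vector in the test, up to float32 rounding:
# 0.43775, 0.49475, 0.48975, 0.339, 0.45675, 0.303, 0.56975, 0.57025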