From dfff0360d92eaf9116bbe8ea56ee55b2a6989798 Mon Sep 17 00:00:00 2001
From: LordMathis
Date: Wed, 20 Mar 2024 21:44:04 +0100
Subject: [PATCH] Implement max pooling test

---
 CMakeLists.txt                  |  4 +-
 include/kernels/pooling.cuh     |  2 +
 include/layers/max_pooling.cuh  | 11 +++++-
 src/kernels/pooling.cu          |  8 ++--
 src/layers/avg_pooling.cu       |  2 +-
 src/layers/max_pooling.cu       |  4 +-
 test/CMakeLists.txt             | 14 +++---
 test/layers/test_max_pooling.cu | 70 +++++++++++++++++++++++++++++++++
 tools/generate_test_results.py  | 51 ++++++++++++++++--------
 9 files changed, 134 insertions(+), 32 deletions(-)
 create mode 100644 test/layers/test_max_pooling.cu

diff --git a/CMakeLists.txt b/CMakeLists.txt
index fce8bea..792d825 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -11,7 +11,8 @@ file(GLOB_RECURSE LIBRARY_SOURCES
     src/*.cu
     src/utils/*.cu
     src/kernels/*.cu
-    src/layers/*.cu)
+    src/layers/*.cu
+)
 
 set(LIBRARY_SOURCES
     ${LIBRARY_SOURCES}
@@ -19,7 +20,6 @@ set(LIBRARY_SOURCES
 
 set(CMAKE_CUDA_ARCHITECTURES 75)
 set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
-# set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} -arch=sm_75)
 
 # Build static library
 add_library(${PROJECT_NAME} STATIC ${LIBRARY_SOURCES})
diff --git a/include/kernels/pooling.cuh b/include/kernels/pooling.cuh
index 6a74010..6f4e3b8 100644
--- a/include/kernels/pooling.cuh
+++ b/include/kernels/pooling.cuh
@@ -9,6 +9,7 @@ __global__ void max_pooling(
     const float* __restrict__ d_input,
     float* __restrict__ d_output,
     const int inputSize,
+    const int outputSize,
     const int nChannels,
     const int poolingSize,
     const int stride
@@ -18,6 +19,7 @@ __global__ void avg_pooling(
     const float* __restrict__ d_input,
     float* __restrict__ d_output,
     const int inputSize,
+    const int outputSize,
     const int nChannels,
     const int poolingSize,
     const int stride
diff --git a/include/layers/max_pooling.cuh b/include/layers/max_pooling.cuh
index 90a2b21..05044c9 100644
--- a/include/layers/max_pooling.cuh
+++ b/include/layers/max_pooling.cuh
@@ -3,8 +3,8 @@
 
 #include <cuda_runtime.h>
 
-#include "layer.cuh"
 #include "activation.cuh"
+#include "layer.cuh"
 
 namespace CUDANet::Layers {
@@ -21,6 +21,15 @@
 
     float* forward(const float* d_input);
 
+    /**
+     * @brief Get the output width (and height) of the layer
+     *
+     * @return int
+     */
+    int getOutputSize() {
+        return outputSize;
+    }
+
   private:
     int inputSize;
     int nChannels;
diff --git a/src/kernels/pooling.cu b/src/kernels/pooling.cu
index d1a5066..baf59a7 100644
--- a/src/kernels/pooling.cu
+++ b/src/kernels/pooling.cu
@@ -7,6 +7,7 @@ __global__ void Kernels::max_pooling(
     const float* __restrict__ d_input,
     float* __restrict__ d_output,
     const int inputSize,
+    const int outputSize,
     const int nChannels,
     const int poolingSize,
     const int stride
@@ -15,7 +16,7 @@
     int i = blockDim.y * blockIdx.y + threadIdx.y;
     int c = blockDim.z * blockIdx.z + threadIdx.z;
 
-    if (i >= inputSize || j >= inputSize || c >= nChannels) {
+    if (i >= outputSize || j >= outputSize || c >= nChannels) {
         return;
     }
 
@@ -32,13 +33,14 @@
         }
     }
 
-    d_output[c * inputSize * inputSize + i * inputSize + j] = max;
+    d_output[c * outputSize * outputSize + i * outputSize + j] = max;
 }
 
 __global__ void Kernels::avg_pooling(
     const float* __restrict__ d_input,
     float* __restrict__ d_output,
     const int inputSize,
+    const int outputSize,
     const int nChannels,
     const int poolingSize,
     const int stride
@@ -62,6 +64,6 @@
         }
     }
 
-    d_output[c * inputSize * inputSize + i * inputSize + j] =
+    d_output[c * outputSize * outputSize + i * outputSize + j] =
         sum / (poolingSize * poolingSize);
 }
\ No newline at end of file
diff --git a/src/layers/avg_pooling.cu b/src/layers/avg_pooling.cu
index 5c61566..70d735d 100644
--- a/src/layers/avg_pooling.cu
+++ b/src/layers/avg_pooling.cu
@@ -43,7 +43,7 @@ float* AvgPooling2D::forward(const float* d_input) {
     );
 
     Kernels::avg_pooling<<<grid, block>>>(
-        d_input, d_output, inputSize, nChannels, poolingSize, stride
+        d_input, d_output, inputSize, outputSize, nChannels, poolingSize, stride
     );
 
     return d_output;
diff --git a/src/layers/max_pooling.cu b/src/layers/max_pooling.cu
index 3dac272..1aad452 100644
--- a/src/layers/max_pooling.cu
+++ b/src/layers/max_pooling.cu
@@ -15,7 +15,7 @@ MaxPooling2D::MaxPooling2D(
     : inputSize(inputSize),
       nChannels(nChannels),
       poolingSize(poolingSize),
      stride(stride) {
-    outputSize = (inputSize - poolingSize) / stride + 1;
+    outputSize = (inputSize - 1) / stride + 1;
     activation = Activation(
         activationType, outputSize * outputSize * nChannels
@@ -46,7 +46,7 @@ float* MaxPooling2D::forward(const float* d_input) {
     );
 
     Kernels::max_pooling<<<grid, block>>>(
-        d_input, d_output, inputSize, nChannels, poolingSize, stride
+        d_input, d_output, inputSize, outputSize, nChannels, poolingSize, stride
     );
 
     return d_output;
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 27a4b6f..4a780fd 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -1,15 +1,15 @@
 find_package(GTest REQUIRED)
 include_directories(${GTEST_INCLUDE_DIRS})
 
+file(GLOB_RECURSE TEST_SOURCES
+    *.cu
+    kernels/*.cu
+    layers/*.cu
+)
+
 add_executable(test_main
     EXCLUDE_FROM_ALL
-    layers/test_activation.cu
-    layers/test_concat.cu
-    layers/test_conv2d.cu
-    layers/test_dense.cu
-    layers/test_input.cu
-    kernels/test_activation_functions.cu
-    kernels/test_matmul.cu
+    ${TEST_SOURCES}
 )
 
 target_link_libraries(test_main ${GTEST_BOTH_LIBRARIES} CUDANet)
diff --git a/test/layers/test_max_pooling.cu b/test/layers/test_max_pooling.cu
new file mode 100644
index 0000000..b704fa7
--- /dev/null
+++ b/test/layers/test_max_pooling.cu
@@ -0,0 +1,70 @@
+#include <cuda_runtime.h>
+#include <gtest/gtest.h>
+
+#include <vector>
+
+#include "max_pooling.cuh"
+
+TEST(MaxPoolingLayerTest, MaxPoolForwardTest) {
+    int inputSize   = 4;
+    int nChannels   = 2;
+    int poolingSize = 2;
+    int stride      = 2;
+
+    cudaError_t cudaStatus;
+
+    std::vector<float> input = {
+        // clang-format off
+        // Channel 0
+        0.573f, 0.619f, 0.732f, 0.055f,
+        0.243f, 0.316f, 0.573f, 0.619f,
+        0.712f, 0.055f, 0.243f, 0.316f,
+        0.573f, 0.619f, 0.742f, 0.055f,
+        // Channel 1
+        0.473f, 0.919f, 0.107f, 0.073f,
+        0.073f, 0.362f, 0.973f, 0.059f,
+        0.473f, 0.455f, 0.283f, 0.416f,
+        0.532f, 0.819f, 0.732f, 0.850f
+        // clang-format on
+    };
+
+    CUDANet::Layers::MaxPooling2D maxPoolingLayer(
+        inputSize, nChannels, poolingSize, stride,
+        CUDANet::Layers::ActivationType::NONE
+    );
+
+    float *d_input;
+    cudaStatus = cudaMalloc(
+        (void **)&d_input, sizeof(float) * inputSize * inputSize * nChannels
+    );
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    cudaStatus = cudaMemcpy(
+        d_input, input.data(),
+        sizeof(float) * inputSize * inputSize * nChannels,
+        cudaMemcpyHostToDevice
+    );
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    float *d_output = maxPoolingLayer.forward(d_input);
+
+    int outputSize = maxPoolingLayer.getOutputSize();
+
+    std::vector<float> output(outputSize * outputSize * nChannels);
+    cudaStatus = cudaMemcpy(
+        output.data(), d_output,
+        sizeof(float) * outputSize * outputSize * nChannels,
+        cudaMemcpyDeviceToHost
+    );
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    std::vector<float> expected = {0.619f, 0.732f, 0.712f, 0.742f,
+                                   0.919f, 0.973f, 0.819f, 0.85f};
+
+    for (int i = 0; i < output.size(); ++i) {
+        EXPECT_FLOAT_EQ(expected[i], output[i]);
+    }
+
+    cudaFree(d_input);
+    cudaFree(d_output);
+}
diff --git a/tools/generate_test_results.py b/tools/generate_test_results.py
index e63b29f..bfb71ad 100644
--- a/tools/generate_test_results.py
+++ b/tools/generate_test_results.py
@@ -1,6 +1,6 @@
 import torch
 
-def conv2d(in_channels, out_channels, kernel_size, stride, padding, inputs, weights):
+def _conv2d(in_channels, out_channels, kernel_size, stride, padding, inputs, weights):
     conv2d = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                              kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
     conv2d.weight = torch.nn.Parameter(weights)
@@ -11,7 +11,7 @@ def conv2d(in_channels, out_channels, kernel_size, stride, padding, inputs, weig
     output = torch.flatten(output)
     return output
 
-def print_cpp_vector(vector):
+def _print_cpp_vector(vector):
     print("std::vector<float> expected = {", end="")
     for i in range(len(vector)):
         if i != 0:
@@ -20,7 +20,7 @@ def print_cpp_vector(vector):
     print("};")
 
 
-def gen_padded_test_result():
+def gen_conv2d_padded_test_result():
 
     in_channels = 3
    out_channels = 2
@@ -68,10 +68,10 @@ def gen_padded_test_result():
         0.011, 0.345, 0.678
     ], dtype=torch.float).reshape(2, 3, 3, 3)
 
-    output = conv2d(in_channels, out_channels, kernel_size, stride, padding, inputs, weights)
-    print_cpp_vector(output)
+    output = _conv2d(in_channels, out_channels, kernel_size, stride, padding, inputs, weights)
+    _print_cpp_vector(output)
 
-def gen_strided_test_result():
+def gen_conv2d_strided_test_result():
 
     in_channels = 2
     out_channels = 2
@@ -106,8 +106,8 @@ def gen_strided_test_result():
         0.939, 0.891, 0.006
     ], dtype=torch.float).reshape(2, 2, 3, 3)
 
-    output = conv2d(in_channels, out_channels, kernel_size, stride, padding, input, weights)
-    print_cpp_vector(output)
+    output = _conv2d(in_channels, out_channels, kernel_size, stride, padding, input, weights)
+    _print_cpp_vector(output)
 
 def gen_softmax_test_result():
     input = torch.tensor([
@@ -115,14 +115,33 @@ def gen_softmax_test_result():
     ])
 
     output = torch.nn.Softmax(dim=0)(input)
-    print_cpp_vector(output)
+    _print_cpp_vector(output)
+
+def gen_max_pool_test_result():
+    input = torch.tensor([
+        0.573, 0.619, 0.732, 0.055,
+        0.243, 0.316, 0.573, 0.619,
+        0.712, 0.055, 0.243, 0.316,
+        0.573, 0.619, 0.742, 0.055,
+        0.473, 0.919, 0.107, 0.073,
+        0.073, 0.362, 0.973, 0.059,
+        0.473, 0.455, 0.283, 0.416,
+        0.532, 0.819, 0.732, 0.850
+    ]).reshape(1, 2, 4, 4)
+
+    output = torch.nn.MaxPool2d(kernel_size=2, stride=2)(input)
+    output = torch.flatten(output)
+
+    _print_cpp_vector(output)
 
 if __name__ == "__main__":
-    print("Generating test results...")
-    print("Padded convolution test:")
-    gen_padded_test_result()
-    print("Strided convolution test:")
-    gen_strided_test_result()
-    print("Softmax test:")
-    gen_softmax_test_result()
\ No newline at end of file
+    # print("Generating test results...")
+    # print("Padded convolution test:")
+    # gen_conv2d_padded_test_result()
+    # print("Strided convolution test:")
+    # gen_conv2d_strided_test_result()
+    # print("Softmax test:")
+    # gen_softmax_test_result()
+    print("Max pool test:")
+    gen_max_pool_test_result()
\ No newline at end of file
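
Reviewer note (illustrative, not part of the patch): the standalone sketch below sanity-checks the patched indexing. It reuses the kernel signature, the outputSize bounds check, and the output write exactly as they appear in src/kernels/pooling.cu above, together with the constructor's outputSize = (inputSize - 1) / stride + 1 formula. Everything else is an assumption for illustration: the window-scan loop (which the diff does not show), the main() harness, the block/grid dimensions, and the h_input/h_output names. One caveat: the new output-size formula ignores poolingSize, so it agrees with PyTorch's default MaxPool2d (used by tools/generate_test_results.py) only when the window tiles the input exactly, as with the 4x4 input, poolingSize 2, stride 2 used in the test.

// sanity_max_pooling.cu: standalone sketch, not part of the patch.
// Mirrors the patched kernel's outputSize-based bounds check and output
// indexing; the inner loop and the harness are assumptions.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void max_pooling(
    const float* __restrict__ d_input,
    float* __restrict__ d_output,
    const int inputSize,
    const int outputSize,
    const int nChannels,
    const int poolingSize,
    const int stride
) {
    int j = blockDim.x * blockIdx.x + threadIdx.x;
    int i = blockDim.y * blockIdx.y + threadIdx.y;
    int c = blockDim.z * blockIdx.z + threadIdx.z;

    // Each thread now maps to one *output* cell; the patch changed this
    // check from inputSize to outputSize.
    if (i >= outputSize || j >= outputSize || c >= nChannels) {
        return;
    }

    // Assumed window scan (not shown in the diff); fine for the
    // non-negative test inputs.
    float max = 0.0f;
    for (int k = 0; k < poolingSize; k++) {
        for (int l = 0; l < poolingSize; l++) {
            float value = d_input[c * inputSize * inputSize +
                                  (i * stride + k) * inputSize + (j * stride + l)];
            if (value > max) {
                max = value;
            }
        }
    }

    // Output indexed by outputSize, matching the patched write.
    d_output[c * outputSize * outputSize + i * outputSize + j] = max;
}

int main() {
    const int inputSize = 4, nChannels = 1, poolingSize = 2, stride = 2;
    const int outputSize = (inputSize - 1) / stride + 1;  // formula from the MaxPooling2D ctor

    // Channel 0 of the test input; expected pooled output is
    // 0.619, 0.732, 0.712, 0.742.
    const float h_input[inputSize * inputSize] = {
        0.573f, 0.619f, 0.732f, 0.055f,
        0.243f, 0.316f, 0.573f, 0.619f,
        0.712f, 0.055f, 0.243f, 0.316f,
        0.573f, 0.619f, 0.742f, 0.055f
    };

    float *d_input, *d_output;
    cudaMalloc(&d_input, sizeof(h_input));
    cudaMalloc(&d_output, sizeof(float) * outputSize * outputSize);
    cudaMemcpy(d_input, h_input, sizeof(h_input), cudaMemcpyHostToDevice);

    dim3 block(8, 8, 1);
    dim3 grid(1, 1, 1);  // the 2x2x1 output fits in a single block
    max_pooling<<<grid, block>>>(
        d_input, d_output, inputSize, outputSize, nChannels, poolingSize, stride
    );

    float h_output[outputSize * outputSize];
    cudaMemcpy(h_output, d_output, sizeof(h_output), cudaMemcpyDeviceToHost);
    for (int i = 0; i < outputSize * outputSize; i++) {
        printf("%.3f\n", h_output[i]);
    }

    cudaFree(d_input);
    cudaFree(d_output);
    return 0;
}

Compiled with something like nvcc -arch=sm_75 sanity_max_pooling.cu -o sanity, this prints 0.619, 0.732, 0.712, 0.742, the channel-0 half of the expected vector in test_max_pooling.cu.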