Implement avg pool test

This commit is contained in:
2024-03-20 21:57:22 +01:00
parent dfff0360d9
commit 6f4cdf3792
7 changed files with 147 additions and 32 deletions

View File

@@ -19,6 +19,15 @@ class AvgPooling2D : public SequentialLayer {
float* forward(const float* d_input);
/**
* @brief Get the output width (/ height) of the layer
*
* @return int
*/
int getOutputSize() {
return outputSize;
}
private:
int inputSize;
int nChannels;
@@ -26,7 +35,6 @@ class AvgPooling2D : public SequentialLayer {
int stride;
int outputSize;
int gridSize;
float* d_output;

View File

@@ -37,7 +37,6 @@ class MaxPooling2D : public SequentialLayer {
int stride;
int outputSize;
int gridSize;
float* d_output;

View File

@@ -49,7 +49,7 @@ __global__ void Kernels::avg_pooling(
int i = blockDim.y * blockIdx.y + threadIdx.y;
int c = blockDim.z * blockIdx.z + threadIdx.z;
if (i >= inputSize || j >= inputSize || c >= nChannels) {
if (i >= outputSize || j >= outputSize || c >= outputSize) {
return;
}

View File

@@ -24,9 +24,6 @@ AvgPooling2D::AvgPooling2D(
CUDA_CHECK(cudaMalloc(
(void**)&d_output, sizeof(float) * outputSize * outputSize * nChannels
));
gridSize =
(outputSize * outputSize * nChannels + BLOCK_SIZE - 1) / BLOCK_SIZE;
}
AvgPooling2D::~AvgPooling2D() {

View File

@@ -25,9 +25,6 @@ MaxPooling2D::MaxPooling2D(
CUDA_CHECK(cudaMalloc(
(void**)&d_output, sizeof(float) * outputSize * outputSize * nChannels
));
gridSize = (outputSize * outputSize * nChannels + BLOCK_SIZE - 1) / BLOCK_SIZE;
}

View File

@@ -0,0 +1,70 @@
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <vector>
#include "avg_pooling.cuh"
TEST(AvgPoolingLayerTest, AvgPoolForwardTest) {
int inputSize = 4;
int nChannels = 2;
int poolingSize = 2;
int stride = 2;
cudaError_t cudaStatus;
std::vector<float> input = {
// clang-format off
// Channel 0
0.573f, 0.619f, 0.732f, 0.055f,
0.243f, 0.316f, 0.573f, 0.619f,
0.712f, 0.055f, 0.243f, 0.316f,
0.573f, 0.619f, 0.742f, 0.055f,
// Channel 1
0.473f, 0.919f, 0.107f, 0.073f,
0.073f, 0.362f, 0.973f, 0.059f,
0.473f, 0.455f, 0.283f, 0.416f,
0.532f, 0.819f, 0.732f, 0.850f
// clang-format on
};
CUDANet::Layers::AvgPooling2D avgPoolingLayer(
inputSize, nChannels, poolingSize, stride,
CUDANet::Layers::ActivationType::NONE
);
float *d_input;
cudaStatus = cudaMalloc(
(void **)&d_input, sizeof(float) * inputSize * inputSize * nChannels
);
EXPECT_EQ(cudaStatus, cudaSuccess);
cudaStatus = cudaMemcpy(
d_input, input.data(),
sizeof(float) * inputSize * inputSize * nChannels,
cudaMemcpyHostToDevice
);
EXPECT_EQ(cudaStatus, cudaSuccess);
float *d_output = avgPoolingLayer.forward(d_input);
int outputSize = avgPoolingLayer.getOutputSize();
std::vector<float> output(outputSize * outputSize * nChannels);
cudaStatus = cudaMemcpy(
output.data(), d_output,
sizeof(float) * outputSize * outputSize * nChannels,
cudaMemcpyDeviceToHost
);
EXPECT_EQ(cudaStatus, cudaSuccess);
std::vector<float> expected = {0.43775f, 0.49475f, 0.48975f, 0.339f, 0.45675f, 0.303f, 0.56975f, 0.57025f};
for (int i = 0; i < output.size(); ++i) {
EXPECT_NEAR(expected[i], output[i], 1e-4);
}
cudaFree(d_input);
cudaFree(d_output);
}

View File

@@ -1,8 +1,20 @@
import torch
def _conv2d(in_channels, out_channels, kernel_size, stride, padding, inputs, weights):
conv2d = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
def _conv2d(in_channels,
out_channels,
kernel_size,
stride,
padding,
inputs,
weights):
conv2d = torch.nn.Conv2d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=False)
conv2d.weight = torch.nn.Parameter(weights)
output = conv2d(inputs)
@@ -11,6 +23,7 @@ def _conv2d(in_channels, out_channels, kernel_size, stride, padding, inputs, wei
output = torch.flatten(output)
return output
def _print_cpp_vector(vector):
print("std::vector<float> expected = {", end="")
for i in range(len(vector)):
@@ -20,6 +33,19 @@ def _print_cpp_vector(vector):
print("};")
def _get_pool_input():
return torch.tensor([
0.573, 0.619, 0.732, 0.055,
0.243, 0.316, 0.573, 0.619,
0.712, 0.055, 0.243, 0.316,
0.573, 0.619, 0.742, 0.055,
0.473, 0.919, 0.107, 0.073,
0.073, 0.362, 0.973, 0.059,
0.473, 0.455, 0.283, 0.416,
0.532, 0.819, 0.732, 0.850
]).reshape(1, 2, 4, 4)
def gen_convd_padded_test_result():
in_channels = 3
@@ -68,9 +94,16 @@ def gen_convd_padded_test_result():
0.011, 0.345, 0.678
], dtype=torch.float).reshape(2, 3, 3, 3)
output = _conv2d(in_channels, out_channels, kernel_size, stride, padding, inputs, weights)
output = _conv2d(in_channels,
out_channels,
kernel_size,
stride,
padding,
inputs,
weights)
_print_cpp_vector(output)
def gen_convd_strided_test_result():
in_channels = 2
@@ -78,7 +111,7 @@ def gen_convd_strided_test_result():
kernel_size = 3
stride = 2
padding = 3
input = torch.tensor([
0.946, 0.879, 0.382, 0.542, 0.453,
0.128, 0.860, 0.778, 0.049, 0.974,
@@ -106,9 +139,16 @@ def gen_convd_strided_test_result():
0.939, 0.891, 0.006
], dtype=torch.float).reshape(2, 2, 3, 3)
output = _conv2d(in_channels, out_channels, kernel_size, stride, padding, input, weights)
output = _conv2d(in_channels,
out_channels,
kernel_size,
stride,
padding,
input,
weights)
_print_cpp_vector(output)
def gen_softmax_test_result():
input = torch.tensor([
0.573, 0.619, 0.732, 0.055, 0.243
@@ -117,17 +157,9 @@ def gen_softmax_test_result():
output = torch.nn.Softmax(dim=0)(input)
_print_cpp_vector(output)
def gen_max_pool_test_result():
input = torch.tensor([
0.573, 0.619, 0.732, 0.055,
0.243, 0.316, 0.573, 0.619,
0.712, 0.055, 0.243, 0.316,
0.573, 0.619, 0.742, 0.055,
0.473, 0.919, 0.107, 0.073,
0.073, 0.362, 0.973, 0.059,
0.473, 0.455, 0.283, 0.416,
0.532, 0.819, 0.732, 0.850
]).reshape(1, 2, 4, 4)
input = _get_pool_input()
output = torch.nn.MaxPool2d(kernel_size=2, stride=2)(input)
output = torch.flatten(output)
@@ -135,13 +167,25 @@ def gen_max_pool_test_result():
_print_cpp_vector(output)
def gen_avg_pool_test_result():
input = _get_pool_input()
output = torch.nn.AvgPool2d(kernel_size=2, stride=2)(input)
output = torch.flatten(output)
_print_cpp_vector(output)
if __name__ == "__main__":
# print("Generating test results...")
# print("Padded convolution test:")
# gen_convd_padded_test_result()
# print("Strided convolution test:")
# gen_convd_strided_test_result()
# print("Softmax test:")
# gen_softmax_test_result()
print("Generating test results...")
print("Padded convolution test:")
gen_convd_padded_test_result()
print("Strided convolution test:")
gen_convd_strided_test_result()
print("Softmax test:")
gen_softmax_test_result()
print("Max pool test:")
gen_max_pool_test_result()
gen_max_pool_test_result()
print("Avg pool test:")
gen_avg_pool_test_result()