Implement max pooling test

This commit is contained in:
2024-03-20 21:44:04 +01:00
parent c062e89972
commit dfff0360d9
9 changed files with 134 additions and 32 deletions

View File

@@ -11,7 +11,8 @@ file(GLOB_RECURSE LIBRARY_SOURCES
src/*.cu
src/utils/*.cu
src/kernels/*.cu
src/layers/*.cu)
src/layers/*.cu
)
set(LIBRARY_SOURCES
${LIBRARY_SOURCES}
@@ -19,7 +20,6 @@ set(LIBRARY_SOURCES
set(CMAKE_CUDA_ARCHITECTURES 75)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
# set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} -arch=sm_75)
# Build static library
add_library(${PROJECT_NAME} STATIC ${LIBRARY_SOURCES})

View File

@@ -9,6 +9,7 @@ __global__ void max_pooling(
const float* __restrict__ d_input,
float* __restrict__ d_output,
const int inputSize,
const int outputSize,
const int nChannels,
const int poolingSize,
const int stride
@@ -18,6 +19,7 @@ __global__ void avg_pooling(
const float* __restrict__ d_input,
float* __restrict__ d_output,
const int inputSize,
const int outputSize,
const int nChannels,
const int poolingSize,
const int stride

View File

@@ -3,8 +3,8 @@
#include <cuda_runtime.h>
#include "layer.cuh"
#include "activation.cuh"
#include "layer.cuh"
namespace CUDANet::Layers {
@@ -21,6 +21,15 @@ class MaxPooling2D : public SequentialLayer {
float* forward(const float* d_input);
/**
* @brief Get the output width (/ height) of the layer
*
* @return int
*/
int getOutputSize() {
return outputSize;
}
private:
int inputSize;
int nChannels;

View File

@@ -7,6 +7,7 @@ __global__ void Kernels::max_pooling(
const float* __restrict__ d_input,
float* __restrict__ d_output,
const int inputSize,
const int outputSize,
const int nChannels,
const int poolingSize,
const int stride
@@ -15,7 +16,7 @@ __global__ void Kernels::max_pooling(
int i = blockDim.y * blockIdx.y + threadIdx.y;
int c = blockDim.z * blockIdx.z + threadIdx.z;
if (i >= inputSize || j >= inputSize || c >= nChannels) {
if (i >= outputSize || j >= outputSize || c >= nChannels) {
return;
}
@@ -32,13 +33,14 @@ __global__ void Kernels::max_pooling(
}
}
d_output[c * inputSize * inputSize + i * inputSize + j] = max;
d_output[c * outputSize * outputSize + i * outputSize + j] = max;
}
__global__ void Kernels::avg_pooling(
const float* __restrict__ d_input,
float* __restrict__ d_output,
const int inputSize,
const int outputSize,
const int nChannels,
const int poolingSize,
const int stride
@@ -62,6 +64,6 @@ __global__ void Kernels::avg_pooling(
}
}
d_output[c * inputSize * inputSize + i * inputSize + j] =
d_output[c * outputSize * outputSize + i * outputSize + j] =
sum / (poolingSize * poolingSize);
}

View File

@@ -43,7 +43,7 @@ float* AvgPooling2D::forward(const float* d_input) {
);
Kernels::avg_pooling<<<grid, block>>>(
d_input, d_output, inputSize, nChannels, poolingSize, stride
d_input, d_output, inputSize, outputSize, nChannels, poolingSize, stride
);
return d_output;

View File

@@ -15,7 +15,7 @@ MaxPooling2D::MaxPooling2D(
: inputSize(inputSize), nChannels(nChannels), poolingSize(poolingSize), stride(stride) {
outputSize = (inputSize - poolingSize) / stride + 1;
outputSize = (inputSize - 1) / stride + 1;
activation = Activation(
activationType, outputSize * outputSize * nChannels
@@ -46,7 +46,7 @@ float* MaxPooling2D::forward(const float* d_input) {
);
Kernels::max_pooling<<<grid, block>>>(
d_input, d_output, inputSize, nChannels, poolingSize, stride
d_input, d_output, inputSize, outputSize, nChannels, poolingSize, stride
);
return d_output;

View File

@@ -1,15 +1,15 @@
find_package(GTest REQUIRED)
include_directories(${GTEST_INCLUDE_DIRS})
file(GLOB_RECURSE TEST_SOURCES
*.cu
kernels/*.cu
layers/*.cu
)
add_executable(test_main
EXCLUDE_FROM_ALL
layers/test_activation.cu
layers/test_concat.cu
layers/test_conv2d.cu
layers/test_dense.cu
layers/test_input.cu
kernels/test_activation_functions.cu
kernels/test_matmul.cu
${TEST_SOURCES}
)
target_link_libraries(test_main ${GTEST_BOTH_LIBRARIES} CUDANet)

View File

@@ -0,0 +1,70 @@
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <vector>
#include "max_pooling.cuh"
TEST(MaxPoolingLayerTest, MaxPoolForwardTest) {
int inputSize = 4;
int nChannels = 2;
int poolingSize = 2;
int stride = 2;
cudaError_t cudaStatus;
std::vector<float> input = {
// clang-format off
// Channel 0
0.573f, 0.619f, 0.732f, 0.055f,
0.243f, 0.316f, 0.573f, 0.619f,
0.712f, 0.055f, 0.243f, 0.316f,
0.573f, 0.619f, 0.742f, 0.055f,
// Channel 1
0.473f, 0.919f, 0.107f, 0.073f,
0.073f, 0.362f, 0.973f, 0.059f,
0.473f, 0.455f, 0.283f, 0.416f,
0.532f, 0.819f, 0.732f, 0.850f
// clang-format on
};
CUDANet::Layers::MaxPooling2D maxPoolingLayer(
inputSize, nChannels, poolingSize, stride,
CUDANet::Layers::ActivationType::NONE
);
float *d_input;
cudaStatus = cudaMalloc(
(void **)&d_input, sizeof(float) * inputSize * inputSize * nChannels
);
EXPECT_EQ(cudaStatus, cudaSuccess);
cudaStatus = cudaMemcpy(
d_input, input.data(),
sizeof(float) * inputSize * inputSize * nChannels,
cudaMemcpyHostToDevice
);
EXPECT_EQ(cudaStatus, cudaSuccess);
float *d_output = maxPoolingLayer.forward(d_input);
int outputSize = maxPoolingLayer.getOutputSize();
std::vector<float> output(outputSize * outputSize * nChannels);
cudaStatus = cudaMemcpy(
output.data(), d_output,
sizeof(float) * outputSize * outputSize * nChannels,
cudaMemcpyDeviceToHost
);
EXPECT_EQ(cudaStatus, cudaSuccess);
std::vector<float> expected = {0.619f, 0.732f, 0.712f, 0.742f, 0.919f, 0.973f, 0.819f, 0.85f};
for (int i = 0; i < output.size(); ++i) {
EXPECT_FLOAT_EQ(expected[i], output[i]);
}
cudaFree(d_input);
cudaFree(d_output);
}

View File

@@ -1,6 +1,6 @@
import torch
def conv2d(in_channels, out_channels, kernel_size, stride, padding, inputs, weights):
def _conv2d(in_channels, out_channels, kernel_size, stride, padding, inputs, weights):
conv2d = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
conv2d.weight = torch.nn.Parameter(weights)
@@ -11,7 +11,7 @@ def conv2d(in_channels, out_channels, kernel_size, stride, padding, inputs, weig
output = torch.flatten(output)
return output
def print_cpp_vector(vector):
def _print_cpp_vector(vector):
print("std::vector<float> expected = {", end="")
for i in range(len(vector)):
if i != 0:
@@ -20,7 +20,7 @@ def print_cpp_vector(vector):
print("};")
def gen_padded_test_result():
def gen_convd_padded_test_result():
in_channels = 3
out_channels = 2
@@ -68,10 +68,10 @@ def gen_padded_test_result():
0.011, 0.345, 0.678
], dtype=torch.float).reshape(2, 3, 3, 3)
output = conv2d(in_channels, out_channels, kernel_size, stride, padding, inputs, weights)
print_cpp_vector(output)
output = _conv2d(in_channels, out_channels, kernel_size, stride, padding, inputs, weights)
_print_cpp_vector(output)
def gen_strided_test_result():
def gen_convd_strided_test_result():
in_channels = 2
out_channels = 2
@@ -106,8 +106,8 @@ def gen_strided_test_result():
0.939, 0.891, 0.006
], dtype=torch.float).reshape(2, 2, 3, 3)
output = conv2d(in_channels, out_channels, kernel_size, stride, padding, input, weights)
print_cpp_vector(output)
output = _conv2d(in_channels, out_channels, kernel_size, stride, padding, input, weights)
_print_cpp_vector(output)
def gen_softmax_test_result():
input = torch.tensor([
@@ -115,14 +115,33 @@ def gen_softmax_test_result():
])
output = torch.nn.Softmax(dim=0)(input)
print_cpp_vector(output)
_print_cpp_vector(output)
def gen_max_pool_test_result():
input = torch.tensor([
0.573, 0.619, 0.732, 0.055,
0.243, 0.316, 0.573, 0.619,
0.712, 0.055, 0.243, 0.316,
0.573, 0.619, 0.742, 0.055,
0.473, 0.919, 0.107, 0.073,
0.073, 0.362, 0.973, 0.059,
0.473, 0.455, 0.283, 0.416,
0.532, 0.819, 0.732, 0.850
]).reshape(1, 2, 4, 4)
output = torch.nn.MaxPool2d(kernel_size=2, stride=2)(input)
output = torch.flatten(output)
_print_cpp_vector(output)
if __name__ == "__main__":
print("Generating test results...")
print("Padded convolution test:")
gen_padded_test_result()
print("Strided convolution test:")
gen_strided_test_result()
print("Softmax test:")
gen_softmax_test_result()
# print("Generating test results...")
# print("Padded convolution test:")
# gen_convd_padded_test_result()
# print("Strided convolution test:")
# gen_convd_strided_test_result()
# print("Softmax test:")
# gen_softmax_test_result()
print("Max pool test:")
gen_max_pool_test_result()