mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-06 01:34:22 +00:00
Implement max pooling test
This commit is contained in:
@@ -11,7 +11,8 @@ file(GLOB_RECURSE LIBRARY_SOURCES
|
||||
src/*.cu
|
||||
src/utils/*.cu
|
||||
src/kernels/*.cu
|
||||
src/layers/*.cu)
|
||||
src/layers/*.cu
|
||||
)
|
||||
|
||||
set(LIBRARY_SOURCES
|
||||
${LIBRARY_SOURCES}
|
||||
@@ -19,7 +20,6 @@ set(LIBRARY_SOURCES
|
||||
|
||||
set(CMAKE_CUDA_ARCHITECTURES 75)
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
# set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} -arch=sm_75)
|
||||
|
||||
# Build static library
|
||||
add_library(${PROJECT_NAME} STATIC ${LIBRARY_SOURCES})
|
||||
|
||||
@@ -9,6 +9,7 @@ __global__ void max_pooling(
|
||||
const float* __restrict__ d_input,
|
||||
float* __restrict__ d_output,
|
||||
const int inputSize,
|
||||
const int outputSize,
|
||||
const int nChannels,
|
||||
const int poolingSize,
|
||||
const int stride
|
||||
@@ -18,6 +19,7 @@ __global__ void avg_pooling(
|
||||
const float* __restrict__ d_input,
|
||||
float* __restrict__ d_output,
|
||||
const int inputSize,
|
||||
const int outputSize,
|
||||
const int nChannels,
|
||||
const int poolingSize,
|
||||
const int stride
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include "layer.cuh"
|
||||
#include "activation.cuh"
|
||||
#include "layer.cuh"
|
||||
|
||||
namespace CUDANet::Layers {
|
||||
|
||||
@@ -21,6 +21,15 @@ class MaxPooling2D : public SequentialLayer {
|
||||
|
||||
float* forward(const float* d_input);
|
||||
|
||||
/**
|
||||
* @brief Get the output width (/ height) of the layer
|
||||
*
|
||||
* @return int
|
||||
*/
|
||||
int getOutputSize() {
|
||||
return outputSize;
|
||||
}
|
||||
|
||||
private:
|
||||
int inputSize;
|
||||
int nChannels;
|
||||
|
||||
@@ -7,6 +7,7 @@ __global__ void Kernels::max_pooling(
|
||||
const float* __restrict__ d_input,
|
||||
float* __restrict__ d_output,
|
||||
const int inputSize,
|
||||
const int outputSize,
|
||||
const int nChannels,
|
||||
const int poolingSize,
|
||||
const int stride
|
||||
@@ -15,7 +16,7 @@ __global__ void Kernels::max_pooling(
|
||||
int i = blockDim.y * blockIdx.y + threadIdx.y;
|
||||
int c = blockDim.z * blockIdx.z + threadIdx.z;
|
||||
|
||||
if (i >= inputSize || j >= inputSize || c >= nChannels) {
|
||||
if (i >= outputSize || j >= outputSize || c >= nChannels) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -32,13 +33,14 @@ __global__ void Kernels::max_pooling(
|
||||
}
|
||||
}
|
||||
|
||||
d_output[c * inputSize * inputSize + i * inputSize + j] = max;
|
||||
d_output[c * outputSize * outputSize + i * outputSize + j] = max;
|
||||
}
|
||||
|
||||
__global__ void Kernels::avg_pooling(
|
||||
const float* __restrict__ d_input,
|
||||
float* __restrict__ d_output,
|
||||
const int inputSize,
|
||||
const int outputSize,
|
||||
const int nChannels,
|
||||
const int poolingSize,
|
||||
const int stride
|
||||
@@ -62,6 +64,6 @@ __global__ void Kernels::avg_pooling(
|
||||
}
|
||||
}
|
||||
|
||||
d_output[c * inputSize * inputSize + i * inputSize + j] =
|
||||
d_output[c * outputSize * outputSize + i * outputSize + j] =
|
||||
sum / (poolingSize * poolingSize);
|
||||
}
|
||||
@@ -43,7 +43,7 @@ float* AvgPooling2D::forward(const float* d_input) {
|
||||
);
|
||||
|
||||
Kernels::avg_pooling<<<grid, block>>>(
|
||||
d_input, d_output, inputSize, nChannels, poolingSize, stride
|
||||
d_input, d_output, inputSize, outputSize, nChannels, poolingSize, stride
|
||||
);
|
||||
|
||||
return d_output;
|
||||
|
||||
@@ -15,7 +15,7 @@ MaxPooling2D::MaxPooling2D(
|
||||
: inputSize(inputSize), nChannels(nChannels), poolingSize(poolingSize), stride(stride) {
|
||||
|
||||
|
||||
outputSize = (inputSize - poolingSize) / stride + 1;
|
||||
outputSize = (inputSize - 1) / stride + 1;
|
||||
|
||||
activation = Activation(
|
||||
activationType, outputSize * outputSize * nChannels
|
||||
@@ -46,7 +46,7 @@ float* MaxPooling2D::forward(const float* d_input) {
|
||||
);
|
||||
|
||||
Kernels::max_pooling<<<grid, block>>>(
|
||||
d_input, d_output, inputSize, nChannels, poolingSize, stride
|
||||
d_input, d_output, inputSize, outputSize, nChannels, poolingSize, stride
|
||||
);
|
||||
|
||||
return d_output;
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
find_package(GTest REQUIRED)
|
||||
include_directories(${GTEST_INCLUDE_DIRS})
|
||||
|
||||
file(GLOB_RECURSE TEST_SOURCES
|
||||
*.cu
|
||||
kernels/*.cu
|
||||
layers/*.cu
|
||||
)
|
||||
|
||||
add_executable(test_main
|
||||
EXCLUDE_FROM_ALL
|
||||
layers/test_activation.cu
|
||||
layers/test_concat.cu
|
||||
layers/test_conv2d.cu
|
||||
layers/test_dense.cu
|
||||
layers/test_input.cu
|
||||
kernels/test_activation_functions.cu
|
||||
kernels/test_matmul.cu
|
||||
${TEST_SOURCES}
|
||||
)
|
||||
|
||||
target_link_libraries(test_main ${GTEST_BOTH_LIBRARIES} CUDANet)
|
||||
|
||||
70
test/layers/test_max_pooling.cu
Normal file
70
test/layers/test_max_pooling.cu
Normal file
@@ -0,0 +1,70 @@
|
||||
#include <cuda_runtime.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "max_pooling.cuh"
|
||||
|
||||
TEST(MaxPoolingLayerTest, MaxPoolForwardTest) {
|
||||
int inputSize = 4;
|
||||
int nChannels = 2;
|
||||
int poolingSize = 2;
|
||||
int stride = 2;
|
||||
|
||||
cudaError_t cudaStatus;
|
||||
|
||||
std::vector<float> input = {
|
||||
// clang-format off
|
||||
// Channel 0
|
||||
0.573f, 0.619f, 0.732f, 0.055f,
|
||||
0.243f, 0.316f, 0.573f, 0.619f,
|
||||
0.712f, 0.055f, 0.243f, 0.316f,
|
||||
0.573f, 0.619f, 0.742f, 0.055f,
|
||||
// Channel 1
|
||||
0.473f, 0.919f, 0.107f, 0.073f,
|
||||
0.073f, 0.362f, 0.973f, 0.059f,
|
||||
0.473f, 0.455f, 0.283f, 0.416f,
|
||||
0.532f, 0.819f, 0.732f, 0.850f
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
CUDANet::Layers::MaxPooling2D maxPoolingLayer(
|
||||
inputSize, nChannels, poolingSize, stride,
|
||||
CUDANet::Layers::ActivationType::NONE
|
||||
);
|
||||
|
||||
float *d_input;
|
||||
|
||||
cudaStatus = cudaMalloc(
|
||||
(void **)&d_input, sizeof(float) * inputSize * inputSize * nChannels
|
||||
);
|
||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
||||
|
||||
cudaStatus = cudaMemcpy(
|
||||
d_input, input.data(),
|
||||
sizeof(float) * inputSize * inputSize * nChannels,
|
||||
cudaMemcpyHostToDevice
|
||||
);
|
||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
||||
|
||||
float *d_output = maxPoolingLayer.forward(d_input);
|
||||
|
||||
int outputSize = maxPoolingLayer.getOutputSize();
|
||||
|
||||
std::vector<float> output(outputSize * outputSize * nChannels);
|
||||
cudaStatus = cudaMemcpy(
|
||||
output.data(), d_output,
|
||||
sizeof(float) * outputSize * outputSize * nChannels,
|
||||
cudaMemcpyDeviceToHost
|
||||
);
|
||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
||||
|
||||
std::vector<float> expected = {0.619f, 0.732f, 0.712f, 0.742f, 0.919f, 0.973f, 0.819f, 0.85f};
|
||||
|
||||
for (int i = 0; i < output.size(); ++i) {
|
||||
EXPECT_FLOAT_EQ(expected[i], output[i]);
|
||||
}
|
||||
|
||||
cudaFree(d_input);
|
||||
cudaFree(d_output);
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
|
||||
def conv2d(in_channels, out_channels, kernel_size, stride, padding, inputs, weights):
|
||||
def _conv2d(in_channels, out_channels, kernel_size, stride, padding, inputs, weights):
|
||||
|
||||
conv2d = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
|
||||
conv2d.weight = torch.nn.Parameter(weights)
|
||||
@@ -11,7 +11,7 @@ def conv2d(in_channels, out_channels, kernel_size, stride, padding, inputs, weig
|
||||
output = torch.flatten(output)
|
||||
return output
|
||||
|
||||
def print_cpp_vector(vector):
|
||||
def _print_cpp_vector(vector):
|
||||
print("std::vector<float> expected = {", end="")
|
||||
for i in range(len(vector)):
|
||||
if i != 0:
|
||||
@@ -20,7 +20,7 @@ def print_cpp_vector(vector):
|
||||
print("};")
|
||||
|
||||
|
||||
def gen_padded_test_result():
|
||||
def gen_convd_padded_test_result():
|
||||
|
||||
in_channels = 3
|
||||
out_channels = 2
|
||||
@@ -68,10 +68,10 @@ def gen_padded_test_result():
|
||||
0.011, 0.345, 0.678
|
||||
], dtype=torch.float).reshape(2, 3, 3, 3)
|
||||
|
||||
output = conv2d(in_channels, out_channels, kernel_size, stride, padding, inputs, weights)
|
||||
print_cpp_vector(output)
|
||||
output = _conv2d(in_channels, out_channels, kernel_size, stride, padding, inputs, weights)
|
||||
_print_cpp_vector(output)
|
||||
|
||||
def gen_strided_test_result():
|
||||
def gen_convd_strided_test_result():
|
||||
|
||||
in_channels = 2
|
||||
out_channels = 2
|
||||
@@ -106,8 +106,8 @@ def gen_strided_test_result():
|
||||
0.939, 0.891, 0.006
|
||||
], dtype=torch.float).reshape(2, 2, 3, 3)
|
||||
|
||||
output = conv2d(in_channels, out_channels, kernel_size, stride, padding, input, weights)
|
||||
print_cpp_vector(output)
|
||||
output = _conv2d(in_channels, out_channels, kernel_size, stride, padding, input, weights)
|
||||
_print_cpp_vector(output)
|
||||
|
||||
def gen_softmax_test_result():
|
||||
input = torch.tensor([
|
||||
@@ -115,14 +115,33 @@ def gen_softmax_test_result():
|
||||
])
|
||||
|
||||
output = torch.nn.Softmax(dim=0)(input)
|
||||
print_cpp_vector(output)
|
||||
_print_cpp_vector(output)
|
||||
|
||||
def gen_max_pool_test_result():
|
||||
input = torch.tensor([
|
||||
0.573, 0.619, 0.732, 0.055,
|
||||
0.243, 0.316, 0.573, 0.619,
|
||||
0.712, 0.055, 0.243, 0.316,
|
||||
0.573, 0.619, 0.742, 0.055,
|
||||
0.473, 0.919, 0.107, 0.073,
|
||||
0.073, 0.362, 0.973, 0.059,
|
||||
0.473, 0.455, 0.283, 0.416,
|
||||
0.532, 0.819, 0.732, 0.850
|
||||
]).reshape(1, 2, 4, 4)
|
||||
|
||||
output = torch.nn.MaxPool2d(kernel_size=2, stride=2)(input)
|
||||
output = torch.flatten(output)
|
||||
|
||||
_print_cpp_vector(output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Generating test results...")
|
||||
print("Padded convolution test:")
|
||||
gen_padded_test_result()
|
||||
print("Strided convolution test:")
|
||||
gen_strided_test_result()
|
||||
print("Softmax test:")
|
||||
gen_softmax_test_result()
|
||||
# print("Generating test results...")
|
||||
# print("Padded convolution test:")
|
||||
# gen_convd_padded_test_result()
|
||||
# print("Strided convolution test:")
|
||||
# gen_convd_strided_test_result()
|
||||
# print("Softmax test:")
|
||||
# gen_softmax_test_result()
|
||||
print("Max pool test:")
|
||||
gen_max_pool_test_result()
|
||||
Reference in New Issue
Block a user