Mirror of https://github.com/lordmathis/CUDANet.git

Commit: Implement max pooling test
@@ -11,7 +11,8 @@ file(GLOB_RECURSE LIBRARY_SOURCES
     src/*.cu
     src/utils/*.cu
     src/kernels/*.cu
-    src/layers/*.cu)
+    src/layers/*.cu
+)
 
 set(LIBRARY_SOURCES
     ${LIBRARY_SOURCES}
@@ -19,7 +20,6 @@ set(LIBRARY_SOURCES
 
 set(CMAKE_CUDA_ARCHITECTURES 75)
 set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
-# set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} -arch=sm_75)
 
 # Build static library
 add_library(${PROJECT_NAME} STATIC ${LIBRARY_SOURCES})
@@ -9,6 +9,7 @@ __global__ void max_pooling(
     const float* __restrict__ d_input,
     float* __restrict__ d_output,
     const int inputSize,
+    const int outputSize,
     const int nChannels,
     const int poolingSize,
     const int stride
@@ -18,6 +19,7 @@ __global__ void avg_pooling(
     const float* __restrict__ d_input,
     float* __restrict__ d_output,
     const int inputSize,
+    const int outputSize,
     const int nChannels,
     const int poolingSize,
     const int stride
@@ -3,8 +3,8 @@
 
 #include <cuda_runtime.h>
 
-#include "layer.cuh"
 #include "activation.cuh"
+#include "layer.cuh"
 
 namespace CUDANet::Layers {
 
@@ -21,6 +21,15 @@ class MaxPooling2D : public SequentialLayer {
 
     float* forward(const float* d_input);
 
+    /**
+     * @brief Get the output width (/ height) of the layer
+     *
+     * @return int
+     */
+    int getOutputSize() {
+        return outputSize;
+    }
+
   private:
     int inputSize;
     int nChannels;
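The inline getter exposes the pooled spatial size to callers, which is what lets host code size its output buffer after forward(); the new test below uses it exactly that way before copying results back from the device.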
@@ -7,6 +7,7 @@ __global__ void Kernels::max_pooling(
     const float* __restrict__ d_input,
     float* __restrict__ d_output,
     const int inputSize,
+    const int outputSize,
     const int nChannels,
     const int poolingSize,
     const int stride
@@ -15,7 +16,7 @@ __global__ void Kernels::max_pooling(
     int i = blockDim.y * blockIdx.y + threadIdx.y;
     int c = blockDim.z * blockIdx.z + threadIdx.z;
 
-    if (i >= inputSize || j >= inputSize || c >= nChannels) {
+    if (i >= outputSize || j >= outputSize || c >= nChannels) {
        return;
     }
 
@@ -32,13 +33,14 @@ __global__ void Kernels::max_pooling(
         }
     }
 
-    d_output[c * inputSize * inputSize + i * inputSize + j] = max;
+    d_output[c * outputSize * outputSize + i * outputSize + j] = max;
 }
 
 __global__ void Kernels::avg_pooling(
     const float* __restrict__ d_input,
     float* __restrict__ d_output,
     const int inputSize,
+    const int outputSize,
     const int nChannels,
     const int poolingSize,
     const int stride
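These two fixes go together: each thread produces one output element, so both the bounds check and the final write must be indexed by the output grid, not the input grid. The window-scan loop between the two changed lines is elided from the diff; the sketch below is a self-contained reconstruction under that assumption (the -FLT_MAX initialisation, the edge guard, and the harness driving channel 0 of the new test are illustrative, not repository code):

    #include <cfloat>
    #include <cstdio>
    #include <cuda_runtime.h>

    __global__ void max_pooling(
        const float* __restrict__ d_input,
        float* __restrict__ d_output,
        const int inputSize,
        const int outputSize,
        const int nChannels,
        const int poolingSize,
        const int stride
    ) {
        int j = blockDim.x * blockIdx.x + threadIdx.x;  // output column
        int i = blockDim.y * blockIdx.y + threadIdx.y;  // output row
        int c = blockDim.z * blockIdx.z + threadIdx.z;  // channel

        // One thread per *output* element, hence the outputSize bound fixed above.
        if (i >= outputSize || j >= outputSize || c >= nChannels) {
            return;
        }

        float max = -FLT_MAX;
        for (int k = 0; k < poolingSize; k++) {
            for (int l = 0; l < poolingSize; l++) {
                int row = i * stride + k;
                int col = j * stride + l;
                // Guard the window edge: with outputSize = (inputSize - 1) / stride + 1
                // the last window can hang over the input boundary.
                if (row < inputSize && col < inputSize) {
                    float val = d_input[c * inputSize * inputSize + row * inputSize + col];
                    if (val > max) {
                        max = val;
                    }
                }
            }
        }

        d_output[c * outputSize * outputSize + i * outputSize + j] = max;
    }

    int main() {
        const int inputSize = 4, outputSize = 2, nChannels = 1, poolingSize = 2, stride = 2;
        const float input[16] = {  // channel 0 of the new test's input
            0.573f, 0.619f, 0.732f, 0.055f,
            0.243f, 0.316f, 0.573f, 0.619f,
            0.712f, 0.055f, 0.243f, 0.316f,
            0.573f, 0.619f, 0.742f, 0.055f
        };

        float *d_in, *d_out;
        cudaMalloc((void **)&d_in, sizeof(input));
        cudaMalloc((void **)&d_out, sizeof(float) * outputSize * outputSize);
        cudaMemcpy(d_in, input, sizeof(input), cudaMemcpyHostToDevice);

        dim3 block(8, 8, 1);
        dim3 grid(1, 1, 1);  // the 2x2x1 output fits in a single block
        max_pooling<<<grid, block>>>(d_in, d_out, inputSize, outputSize,
                                     nChannels, poolingSize, stride);

        float out[4];
        cudaMemcpy(out, d_out, sizeof(out), cudaMemcpyDeviceToHost);
        printf("%.3f %.3f %.3f %.3f\n", out[0], out[1], out[2], out[3]);

        cudaFree(d_in);
        cudaFree(d_out);
        return 0;
    }

Compiled with nvcc, this prints 0.619 0.732 0.712 0.742, the channel-0 values asserted by the test below.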
@@ -62,6 +64,6 @@ __global__ void Kernels::avg_pooling(
         }
     }
 
-    d_output[c * inputSize * inputSize + i * inputSize + j] =
+    d_output[c * outputSize * outputSize + i * outputSize + j] =
         sum / (poolingSize * poolingSize);
 }
@@ -43,7 +43,7 @@ float* AvgPooling2D::forward(const float* d_input) {
     );
 
     Kernels::avg_pooling<<<grid, block>>>(
-        d_input, d_output, inputSize, nChannels, poolingSize, stride
+        d_input, d_output, inputSize, outputSize, nChannels, poolingSize, stride
     );
 
     return d_output;
@@ -15,7 +15,7 @@ MaxPooling2D::MaxPooling2D(
     : inputSize(inputSize), nChannels(nChannels), poolingSize(poolingSize), stride(stride) {
 
-    outputSize = (inputSize - poolingSize) / stride + 1;
+    outputSize = (inputSize - 1) / stride + 1;
 
     activation = Activation(
         activationType, outputSize * outputSize * nChannels
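For the shapes the new test exercises (inputSize = 4, poolingSize = 2, stride = 2) the two formulas agree under integer division: (4 - 2) / 2 + 1 = 2 and (4 - 1) / 2 + 1 = 2. They diverge once a window would overhang the input; for example inputSize = 5, poolingSize = 3, stride = 2 gives 2 with the old formula but 3 with the new one, which is the overhang the edge guard in the kernel sketch above accounts for. Note that torch.nn.MaxPool2d with its default floor rounding follows the old formula, so the generator script below only produces matching expected values for configurations where both formulas agree, as they do here.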
@@ -46,7 +46,7 @@ float* MaxPooling2D::forward(const float* d_input) {
     );
 
     Kernels::max_pooling<<<grid, block>>>(
-        d_input, d_output, inputSize, nChannels, poolingSize, stride
+        d_input, d_output, inputSize, outputSize, nChannels, poolingSize, stride
     );
 
     return d_output;
@@ -1,15 +1,15 @@
 find_package(GTest REQUIRED)
 include_directories(${GTEST_INCLUDE_DIRS})
 
+file(GLOB_RECURSE TEST_SOURCES
+    *.cu
+    kernels/*.cu
+    layers/*.cu
+)
+
 add_executable(test_main
     EXCLUDE_FROM_ALL
-    layers/test_activation.cu
-    layers/test_concat.cu
-    layers/test_conv2d.cu
-    layers/test_dense.cu
-    layers/test_input.cu
-    kernels/test_activation_functions.cu
-    kernels/test_matmul.cu
+    ${TEST_SOURCES}
 )
 
 target_link_libraries(test_main ${GTEST_BOTH_LIBRARIES} CUDANet)
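Switching the test target to file(GLOB_RECURSE ...) means new test files such as test_max_pooling.cu are picked up without editing this list again, with the usual CMake caveat that the glob is evaluated at configure time, so a freshly added source only enters the build after CMake reruns. Note also that the top-level *.cu pattern already matches kernels/*.cu and layers/*.cu recursively, so the two extra patterns are redundant but harmless.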
test/layers/test_max_pooling.cu (new file, 70 lines)
@@ -0,0 +1,70 @@
+#include <cuda_runtime.h>
+#include <gtest/gtest.h>
+
+#include <vector>
+
+#include "max_pooling.cuh"
+
+TEST(MaxPoolingLayerTest, MaxPoolForwardTest) {
+    int inputSize   = 4;
+    int nChannels   = 2;
+    int poolingSize = 2;
+    int stride      = 2;
+
+    cudaError_t cudaStatus;
+
+    std::vector<float> input = {
+        // clang-format off
+        // Channel 0
+        0.573f, 0.619f, 0.732f, 0.055f,
+        0.243f, 0.316f, 0.573f, 0.619f,
+        0.712f, 0.055f, 0.243f, 0.316f,
+        0.573f, 0.619f, 0.742f, 0.055f,
+        // Channel 1
+        0.473f, 0.919f, 0.107f, 0.073f,
+        0.073f, 0.362f, 0.973f, 0.059f,
+        0.473f, 0.455f, 0.283f, 0.416f,
+        0.532f, 0.819f, 0.732f, 0.850f
+        // clang-format on
+    };
+
+    CUDANet::Layers::MaxPooling2D maxPoolingLayer(
+        inputSize, nChannels, poolingSize, stride,
+        CUDANet::Layers::ActivationType::NONE
+    );
+
+    float *d_input;
+
+    cudaStatus = cudaMalloc(
+        (void **)&d_input, sizeof(float) * inputSize * inputSize * nChannels
+    );
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    cudaStatus = cudaMemcpy(
+        d_input, input.data(),
+        sizeof(float) * inputSize * inputSize * nChannels,
+        cudaMemcpyHostToDevice
+    );
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    float *d_output = maxPoolingLayer.forward(d_input);
+
+    int outputSize = maxPoolingLayer.getOutputSize();
+
+    std::vector<float> output(outputSize * outputSize * nChannels);
+    cudaStatus = cudaMemcpy(
+        output.data(), d_output,
+        sizeof(float) * outputSize * outputSize * nChannels,
+        cudaMemcpyDeviceToHost
+    );
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    std::vector<float> expected = {0.619f, 0.732f, 0.712f, 0.742f, 0.919f, 0.973f, 0.819f, 0.85f};
+
+    for (int i = 0; i < output.size(); ++i) {
+        EXPECT_FLOAT_EQ(expected[i], output[i]);
+    }
+
+    cudaFree(d_input);
+    cudaFree(d_output);
+}
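The expected vector checks out by hand: with a 2x2 window and stride 2, channel 0's four windows peak at 0.619, 0.732, 0.712 and 0.742, and channel 1's at 0.919, 0.973, 0.819 and 0.850, matching the eight asserted values in flattened channel-row-column order. One small review note: the loop compares a signed int against the unsigned value returned by output.size(), which compilers typically flag under -Wsign-compare.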
@@ -1,6 +1,6 @@
 import torch
 
-def conv2d(in_channels, out_channels, kernel_size, stride, padding, inputs, weights):
+def _conv2d(in_channels, out_channels, kernel_size, stride, padding, inputs, weights):
 
     conv2d = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
     conv2d.weight = torch.nn.Parameter(weights)
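The leading underscore is more than style here: inside the original conv2d() the first statement binds a local variable named conv2d to a torch.nn.Conv2d instance, shadowing the function of the same name within its own body. Legal, but confusing; the rename removes the collision and, together with _print_cpp_vector, marks both functions as module-private helpers.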
@@ -11,7 +11,7 @@ def conv2d(in_channels, out_channels, kernel_size, stride, padding, inputs, weights):
     output = torch.flatten(output)
     return output
 
-def print_cpp_vector(vector):
+def _print_cpp_vector(vector):
     print("std::vector<float> expected = {", end="")
     for i in range(len(vector)):
         if i != 0:
@@ -20,7 +20,7 @@ def print_cpp_vector(vector):
     print("};")
 
 
-def gen_padded_test_result():
+def gen_convd_padded_test_result():
 
     in_channels = 3
     out_channels = 2
@@ -68,10 +68,10 @@ def gen_padded_test_result():
         0.011, 0.345, 0.678
     ], dtype=torch.float).reshape(2, 3, 3, 3)
 
-    output = conv2d(in_channels, out_channels, kernel_size, stride, padding, inputs, weights)
-    print_cpp_vector(output)
+    output = _conv2d(in_channels, out_channels, kernel_size, stride, padding, inputs, weights)
+    _print_cpp_vector(output)
 
-def gen_strided_test_result():
+def gen_convd_strided_test_result():
 
     in_channels = 2
     out_channels = 2
@@ -106,8 +106,8 @@ def gen_strided_test_result():
         0.939, 0.891, 0.006
     ], dtype=torch.float).reshape(2, 2, 3, 3)
 
-    output = conv2d(in_channels, out_channels, kernel_size, stride, padding, input, weights)
-    print_cpp_vector(output)
+    output = _conv2d(in_channels, out_channels, kernel_size, stride, padding, input, weights)
+    _print_cpp_vector(output)
 
 def gen_softmax_test_result():
     input = torch.tensor([
@@ -115,14 +115,33 @@ def gen_softmax_test_result():
     ])
 
     output = torch.nn.Softmax(dim=0)(input)
-    print_cpp_vector(output)
+    _print_cpp_vector(output)
 
+def gen_max_pool_test_result():
+    input = torch.tensor([
+        0.573, 0.619, 0.732, 0.055,
+        0.243, 0.316, 0.573, 0.619,
+        0.712, 0.055, 0.243, 0.316,
+        0.573, 0.619, 0.742, 0.055,
+        0.473, 0.919, 0.107, 0.073,
+        0.073, 0.362, 0.973, 0.059,
+        0.473, 0.455, 0.283, 0.416,
+        0.532, 0.819, 0.732, 0.850
+    ]).reshape(1, 2, 4, 4)
+
+    output = torch.nn.MaxPool2d(kernel_size=2, stride=2)(input)
+    output = torch.flatten(output)
+
+    _print_cpp_vector(output)
+
+
 if __name__ == "__main__":
-    print("Generating test results...")
-    print("Padded convolution test:")
-    gen_padded_test_result()
-    print("Strided convolution test:")
-    gen_strided_test_result()
-    print("Softmax test:")
-    gen_softmax_test_result()
+    # print("Generating test results...")
+    # print("Padded convolution test:")
+    # gen_convd_padded_test_result()
+    # print("Strided convolution test:")
+    # gen_convd_strided_test_result()
+    # print("Softmax test:")
+    # gen_softmax_test_result()
+    print("Max pool test:")
+    gen_max_pool_test_result()
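gen_max_pool_test_result() mirrors the C++ test exactly: the same 32 values reshaped to a 1x2x4x4 NCHW tensor, pooled with torch.nn.MaxPool2d(kernel_size=2, stride=2) and flattened, print as the expected vector pasted into test_max_pooling.cu. Commenting out the other generators in __main__ rather than deleting them keeps the script focused on the new test while leaving the convolution and softmax generators one uncomment away.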