mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-06 01:34:22 +00:00
Implement avg pool test
This commit is contained in:
@@ -19,6 +19,15 @@ class AvgPooling2D : public SequentialLayer {
|
|||||||
|
|
||||||
float* forward(const float* d_input);
|
float* forward(const float* d_input);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Get the output width (/ height) of the layer
|
||||||
|
*
|
||||||
|
* @return int
|
||||||
|
*/
|
||||||
|
int getOutputSize() {
|
||||||
|
return outputSize;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int inputSize;
|
int inputSize;
|
||||||
int nChannels;
|
int nChannels;
|
||||||
@@ -26,7 +35,6 @@ class AvgPooling2D : public SequentialLayer {
|
|||||||
int stride;
|
int stride;
|
||||||
|
|
||||||
int outputSize;
|
int outputSize;
|
||||||
int gridSize;
|
|
||||||
|
|
||||||
float* d_output;
|
float* d_output;
|
||||||
|
|
||||||
|
|||||||
@@ -37,7 +37,6 @@ class MaxPooling2D : public SequentialLayer {
|
|||||||
int stride;
|
int stride;
|
||||||
|
|
||||||
int outputSize;
|
int outputSize;
|
||||||
int gridSize;
|
|
||||||
|
|
||||||
float* d_output;
|
float* d_output;
|
||||||
|
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ __global__ void Kernels::avg_pooling(
|
|||||||
int i = blockDim.y * blockIdx.y + threadIdx.y;
|
int i = blockDim.y * blockIdx.y + threadIdx.y;
|
||||||
int c = blockDim.z * blockIdx.z + threadIdx.z;
|
int c = blockDim.z * blockIdx.z + threadIdx.z;
|
||||||
|
|
||||||
if (i >= inputSize || j >= inputSize || c >= nChannels) {
|
if (i >= outputSize || j >= outputSize || c >= outputSize) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -24,9 +24,6 @@ AvgPooling2D::AvgPooling2D(
|
|||||||
CUDA_CHECK(cudaMalloc(
|
CUDA_CHECK(cudaMalloc(
|
||||||
(void**)&d_output, sizeof(float) * outputSize * outputSize * nChannels
|
(void**)&d_output, sizeof(float) * outputSize * outputSize * nChannels
|
||||||
));
|
));
|
||||||
|
|
||||||
gridSize =
|
|
||||||
(outputSize * outputSize * nChannels + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
AvgPooling2D::~AvgPooling2D() {
|
AvgPooling2D::~AvgPooling2D() {
|
||||||
|
|||||||
@@ -25,9 +25,6 @@ MaxPooling2D::MaxPooling2D(
|
|||||||
CUDA_CHECK(cudaMalloc(
|
CUDA_CHECK(cudaMalloc(
|
||||||
(void**)&d_output, sizeof(float) * outputSize * outputSize * nChannels
|
(void**)&d_output, sizeof(float) * outputSize * outputSize * nChannels
|
||||||
));
|
));
|
||||||
|
|
||||||
gridSize = (outputSize * outputSize * nChannels + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
70
test/layers/test_avg_pooling.cu
Normal file
70
test/layers/test_avg_pooling.cu
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "avg_pooling.cuh"
|
||||||
|
|
||||||
|
TEST(AvgPoolingLayerTest, AvgPoolForwardTest) {
|
||||||
|
int inputSize = 4;
|
||||||
|
int nChannels = 2;
|
||||||
|
int poolingSize = 2;
|
||||||
|
int stride = 2;
|
||||||
|
|
||||||
|
cudaError_t cudaStatus;
|
||||||
|
|
||||||
|
std::vector<float> input = {
|
||||||
|
// clang-format off
|
||||||
|
// Channel 0
|
||||||
|
0.573f, 0.619f, 0.732f, 0.055f,
|
||||||
|
0.243f, 0.316f, 0.573f, 0.619f,
|
||||||
|
0.712f, 0.055f, 0.243f, 0.316f,
|
||||||
|
0.573f, 0.619f, 0.742f, 0.055f,
|
||||||
|
// Channel 1
|
||||||
|
0.473f, 0.919f, 0.107f, 0.073f,
|
||||||
|
0.073f, 0.362f, 0.973f, 0.059f,
|
||||||
|
0.473f, 0.455f, 0.283f, 0.416f,
|
||||||
|
0.532f, 0.819f, 0.732f, 0.850f
|
||||||
|
// clang-format on
|
||||||
|
};
|
||||||
|
|
||||||
|
CUDANet::Layers::AvgPooling2D avgPoolingLayer(
|
||||||
|
inputSize, nChannels, poolingSize, stride,
|
||||||
|
CUDANet::Layers::ActivationType::NONE
|
||||||
|
);
|
||||||
|
|
||||||
|
float *d_input;
|
||||||
|
|
||||||
|
cudaStatus = cudaMalloc(
|
||||||
|
(void **)&d_input, sizeof(float) * inputSize * inputSize * nChannels
|
||||||
|
);
|
||||||
|
EXPECT_EQ(cudaStatus, cudaSuccess);
|
||||||
|
|
||||||
|
cudaStatus = cudaMemcpy(
|
||||||
|
d_input, input.data(),
|
||||||
|
sizeof(float) * inputSize * inputSize * nChannels,
|
||||||
|
cudaMemcpyHostToDevice
|
||||||
|
);
|
||||||
|
EXPECT_EQ(cudaStatus, cudaSuccess);
|
||||||
|
|
||||||
|
float *d_output = avgPoolingLayer.forward(d_input);
|
||||||
|
|
||||||
|
int outputSize = avgPoolingLayer.getOutputSize();
|
||||||
|
|
||||||
|
std::vector<float> output(outputSize * outputSize * nChannels);
|
||||||
|
cudaStatus = cudaMemcpy(
|
||||||
|
output.data(), d_output,
|
||||||
|
sizeof(float) * outputSize * outputSize * nChannels,
|
||||||
|
cudaMemcpyDeviceToHost
|
||||||
|
);
|
||||||
|
EXPECT_EQ(cudaStatus, cudaSuccess);
|
||||||
|
|
||||||
|
std::vector<float> expected = {0.43775f, 0.49475f, 0.48975f, 0.339f, 0.45675f, 0.303f, 0.56975f, 0.57025f};
|
||||||
|
|
||||||
|
for (int i = 0; i < output.size(); ++i) {
|
||||||
|
EXPECT_NEAR(expected[i], output[i], 1e-4);
|
||||||
|
}
|
||||||
|
|
||||||
|
cudaFree(d_input);
|
||||||
|
cudaFree(d_output);
|
||||||
|
}
|
||||||
@@ -1,8 +1,20 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
def _conv2d(in_channels, out_channels, kernel_size, stride, padding, inputs, weights):
|
|
||||||
|
|
||||||
conv2d = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
|
def _conv2d(in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
inputs,
|
||||||
|
weights):
|
||||||
|
|
||||||
|
conv2d = torch.nn.Conv2d(in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
bias=False)
|
||||||
conv2d.weight = torch.nn.Parameter(weights)
|
conv2d.weight = torch.nn.Parameter(weights)
|
||||||
|
|
||||||
output = conv2d(inputs)
|
output = conv2d(inputs)
|
||||||
@@ -11,6 +23,7 @@ def _conv2d(in_channels, out_channels, kernel_size, stride, padding, inputs, wei
|
|||||||
output = torch.flatten(output)
|
output = torch.flatten(output)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def _print_cpp_vector(vector):
|
def _print_cpp_vector(vector):
|
||||||
print("std::vector<float> expected = {", end="")
|
print("std::vector<float> expected = {", end="")
|
||||||
for i in range(len(vector)):
|
for i in range(len(vector)):
|
||||||
@@ -20,6 +33,19 @@ def _print_cpp_vector(vector):
|
|||||||
print("};")
|
print("};")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_pool_input():
|
||||||
|
return torch.tensor([
|
||||||
|
0.573, 0.619, 0.732, 0.055,
|
||||||
|
0.243, 0.316, 0.573, 0.619,
|
||||||
|
0.712, 0.055, 0.243, 0.316,
|
||||||
|
0.573, 0.619, 0.742, 0.055,
|
||||||
|
0.473, 0.919, 0.107, 0.073,
|
||||||
|
0.073, 0.362, 0.973, 0.059,
|
||||||
|
0.473, 0.455, 0.283, 0.416,
|
||||||
|
0.532, 0.819, 0.732, 0.850
|
||||||
|
]).reshape(1, 2, 4, 4)
|
||||||
|
|
||||||
|
|
||||||
def gen_convd_padded_test_result():
|
def gen_convd_padded_test_result():
|
||||||
|
|
||||||
in_channels = 3
|
in_channels = 3
|
||||||
@@ -68,9 +94,16 @@ def gen_convd_padded_test_result():
|
|||||||
0.011, 0.345, 0.678
|
0.011, 0.345, 0.678
|
||||||
], dtype=torch.float).reshape(2, 3, 3, 3)
|
], dtype=torch.float).reshape(2, 3, 3, 3)
|
||||||
|
|
||||||
output = _conv2d(in_channels, out_channels, kernel_size, stride, padding, inputs, weights)
|
output = _conv2d(in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
inputs,
|
||||||
|
weights)
|
||||||
_print_cpp_vector(output)
|
_print_cpp_vector(output)
|
||||||
|
|
||||||
|
|
||||||
def gen_convd_strided_test_result():
|
def gen_convd_strided_test_result():
|
||||||
|
|
||||||
in_channels = 2
|
in_channels = 2
|
||||||
@@ -78,7 +111,7 @@ def gen_convd_strided_test_result():
|
|||||||
kernel_size = 3
|
kernel_size = 3
|
||||||
stride = 2
|
stride = 2
|
||||||
padding = 3
|
padding = 3
|
||||||
|
|
||||||
input = torch.tensor([
|
input = torch.tensor([
|
||||||
0.946, 0.879, 0.382, 0.542, 0.453,
|
0.946, 0.879, 0.382, 0.542, 0.453,
|
||||||
0.128, 0.860, 0.778, 0.049, 0.974,
|
0.128, 0.860, 0.778, 0.049, 0.974,
|
||||||
@@ -106,9 +139,16 @@ def gen_convd_strided_test_result():
|
|||||||
0.939, 0.891, 0.006
|
0.939, 0.891, 0.006
|
||||||
], dtype=torch.float).reshape(2, 2, 3, 3)
|
], dtype=torch.float).reshape(2, 2, 3, 3)
|
||||||
|
|
||||||
output = _conv2d(in_channels, out_channels, kernel_size, stride, padding, input, weights)
|
output = _conv2d(in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
input,
|
||||||
|
weights)
|
||||||
_print_cpp_vector(output)
|
_print_cpp_vector(output)
|
||||||
|
|
||||||
|
|
||||||
def gen_softmax_test_result():
|
def gen_softmax_test_result():
|
||||||
input = torch.tensor([
|
input = torch.tensor([
|
||||||
0.573, 0.619, 0.732, 0.055, 0.243
|
0.573, 0.619, 0.732, 0.055, 0.243
|
||||||
@@ -117,17 +157,9 @@ def gen_softmax_test_result():
|
|||||||
output = torch.nn.Softmax(dim=0)(input)
|
output = torch.nn.Softmax(dim=0)(input)
|
||||||
_print_cpp_vector(output)
|
_print_cpp_vector(output)
|
||||||
|
|
||||||
|
|
||||||
def gen_max_pool_test_result():
|
def gen_max_pool_test_result():
|
||||||
input = torch.tensor([
|
input = _get_pool_input()
|
||||||
0.573, 0.619, 0.732, 0.055,
|
|
||||||
0.243, 0.316, 0.573, 0.619,
|
|
||||||
0.712, 0.055, 0.243, 0.316,
|
|
||||||
0.573, 0.619, 0.742, 0.055,
|
|
||||||
0.473, 0.919, 0.107, 0.073,
|
|
||||||
0.073, 0.362, 0.973, 0.059,
|
|
||||||
0.473, 0.455, 0.283, 0.416,
|
|
||||||
0.532, 0.819, 0.732, 0.850
|
|
||||||
]).reshape(1, 2, 4, 4)
|
|
||||||
|
|
||||||
output = torch.nn.MaxPool2d(kernel_size=2, stride=2)(input)
|
output = torch.nn.MaxPool2d(kernel_size=2, stride=2)(input)
|
||||||
output = torch.flatten(output)
|
output = torch.flatten(output)
|
||||||
@@ -135,13 +167,25 @@ def gen_max_pool_test_result():
|
|||||||
_print_cpp_vector(output)
|
_print_cpp_vector(output)
|
||||||
|
|
||||||
|
|
||||||
|
def gen_avg_pool_test_result():
|
||||||
|
|
||||||
|
input = _get_pool_input()
|
||||||
|
|
||||||
|
output = torch.nn.AvgPool2d(kernel_size=2, stride=2)(input)
|
||||||
|
output = torch.flatten(output)
|
||||||
|
|
||||||
|
_print_cpp_vector(output)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# print("Generating test results...")
|
print("Generating test results...")
|
||||||
# print("Padded convolution test:")
|
print("Padded convolution test:")
|
||||||
# gen_convd_padded_test_result()
|
gen_convd_padded_test_result()
|
||||||
# print("Strided convolution test:")
|
print("Strided convolution test:")
|
||||||
# gen_convd_strided_test_result()
|
gen_convd_strided_test_result()
|
||||||
# print("Softmax test:")
|
print("Softmax test:")
|
||||||
# gen_softmax_test_result()
|
gen_softmax_test_result()
|
||||||
print("Max pool test:")
|
print("Max pool test:")
|
||||||
gen_max_pool_test_result()
|
gen_max_pool_test_result()
|
||||||
|
print("Avg pool test:")
|
||||||
|
gen_avg_pool_test_result()
|
||||||
|
|||||||
Reference in New Issue
Block a user