mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-06 01:34:22 +00:00
Change forward function to return output pointer
This commit is contained in:
@@ -17,8 +17,7 @@ class Conv2dTest : public ::testing::Test {
|
||||
Layers::Activation activation,
|
||||
std::vector<float>& input,
|
||||
float* kernels,
|
||||
float*& d_input,
|
||||
float*& d_output
|
||||
float*& d_input
|
||||
) {
|
||||
// Create Conv2d layer
|
||||
Layers::Conv2d conv2d(
|
||||
@@ -35,12 +34,6 @@ class Conv2dTest : public ::testing::Test {
|
||||
);
|
||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
||||
|
||||
cudaStatus = cudaMalloc(
|
||||
(void**)&d_output,
|
||||
sizeof(float) * conv2d.outputSize * conv2d.outputSize * numFilters
|
||||
);
|
||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
||||
|
||||
// // Copy input to device
|
||||
cudaStatus = cudaMemcpy(
|
||||
d_input, input.data(), sizeof(float) * input.size(),
|
||||
@@ -51,10 +44,9 @@ class Conv2dTest : public ::testing::Test {
|
||||
return conv2d;
|
||||
}
|
||||
|
||||
void commonTestTeardown(float* d_input, float* d_output) {
|
||||
void commonTestTeardown(float* d_input) {
|
||||
// Free device memory
|
||||
cudaFree(d_input);
|
||||
cudaFree(d_output);
|
||||
}
|
||||
|
||||
cudaError_t cudaStatus;
|
||||
@@ -84,13 +76,13 @@ TEST_F(Conv2dTest, SimpleTest) {
|
||||
|
||||
Layers::Conv2d conv2d = commonTestSetup(
|
||||
inputSize, inputChannels, kernelSize, stride, padding, numFilters,
|
||||
activation, input, kernels.data(), d_input, d_output
|
||||
activation, input, kernels.data(), d_input
|
||||
);
|
||||
|
||||
int outputSize = (inputSize - kernelSize) / stride + 1;
|
||||
EXPECT_EQ(outputSize, conv2d.outputSize);
|
||||
|
||||
conv2d.forward(d_input, d_output);
|
||||
d_output = conv2d.forward(d_input);
|
||||
|
||||
std::vector<float> expected = {44.0f, 54.0f, 64.0f, 84.0f, 94.0f,
|
||||
104.0f, 124.0f, 134.0f, 144.0f};
|
||||
@@ -106,7 +98,7 @@ TEST_F(Conv2dTest, SimpleTest) {
|
||||
EXPECT_FLOAT_EQ(expected[i], output[i]);
|
||||
}
|
||||
|
||||
commonTestTeardown(d_input, d_output);
|
||||
commonTestTeardown(d_input);
|
||||
}
|
||||
|
||||
TEST_F(Conv2dTest, PaddedTest) {
|
||||
@@ -173,12 +165,12 @@ TEST_F(Conv2dTest, PaddedTest) {
|
||||
|
||||
Layers::Conv2d conv2d = commonTestSetup(
|
||||
inputSize, inputChannels, kernelSize, stride, padding, numFilters,
|
||||
activation, input, kernels.data(), d_input, d_output
|
||||
activation, input, kernels.data(), d_input
|
||||
);
|
||||
|
||||
EXPECT_EQ(inputSize, conv2d.outputSize);
|
||||
|
||||
conv2d.forward(d_input, d_output);
|
||||
d_output = conv2d.forward(d_input);
|
||||
|
||||
std::vector<float> output(
|
||||
conv2d.outputSize * conv2d.outputSize * numFilters
|
||||
@@ -192,23 +184,21 @@ TEST_F(Conv2dTest, PaddedTest) {
|
||||
// Generated by tools/generate_conv2d_test.py
|
||||
std::vector<float> expected = {
|
||||
// Channel 1
|
||||
2.29426f, 3.89173f, 4.17634f, 3.25501f, 2.07618f,
|
||||
5.41483f, 7.09971f, 6.39811f, 5.71432f, 3.10928f,
|
||||
5.12973f, 6.29638f, 5.26962f, 5.21997f, 3.05852f,
|
||||
6.17517f, 7.19311f, 6.69771f, 6.2142f, 4.03242f,
|
||||
3.3792f, 4.36444f, 4.396f, 4.69905f, 3.62061f,
|
||||
2.29426f, 3.89173f, 4.17634f, 3.25501f, 2.07618f, 5.41483f, 7.09971f,
|
||||
6.39811f, 5.71432f, 3.10928f, 5.12973f, 6.29638f, 5.26962f, 5.21997f,
|
||||
3.05852f, 6.17517f, 7.19311f, 6.69771f, 6.2142f, 4.03242f, 3.3792f,
|
||||
4.36444f, 4.396f, 4.69905f, 3.62061f,
|
||||
// Channel 2
|
||||
2.87914f, 3.71743f, 3.51854f, 2.98413f, 1.46579f,
|
||||
4.94951f, 6.18983f, 4.98187f, 4.38372f, 3.35386f,
|
||||
5.0364f, 5.3756f, 4.05993f, 4.89299f, 2.78625f,
|
||||
5.33763f, 5.80899f, 5.89785f, 5.51095f, 3.74287f,
|
||||
2.64053f, 4.05895f, 3.96482f, 4.30177f, 1.94269f
|
||||
2.87914f, 3.71743f, 3.51854f, 2.98413f, 1.46579f, 4.94951f, 6.18983f,
|
||||
4.98187f, 4.38372f, 3.35386f, 5.0364f, 5.3756f, 4.05993f, 4.89299f,
|
||||
2.78625f, 5.33763f, 5.80899f, 5.89785f, 5.51095f, 3.74287f, 2.64053f,
|
||||
4.05895f, 3.96482f, 4.30177f, 1.94269f
|
||||
};
|
||||
for (int i = 0; i < output.size(); i++) {
|
||||
EXPECT_NEAR(output[i], expected[i], 0.0001f);
|
||||
}
|
||||
|
||||
commonTestTeardown(d_input, d_output);
|
||||
commonTestTeardown(d_input);
|
||||
}
|
||||
|
||||
TEST_F(Conv2dTest, StridedPaddedConvolution) {
|
||||
@@ -260,12 +250,12 @@ TEST_F(Conv2dTest, StridedPaddedConvolution) {
|
||||
|
||||
Layers::Conv2d conv2d = commonTestSetup(
|
||||
inputSize, inputChannels, kernelSize, stride, padding, numFilters,
|
||||
activation, input, kernels.data(), d_input, d_output
|
||||
activation, input, kernels.data(), d_input
|
||||
);
|
||||
|
||||
EXPECT_EQ(inputSize, conv2d.outputSize);
|
||||
|
||||
conv2d.forward(d_input, d_output);
|
||||
d_output = conv2d.forward(d_input);
|
||||
|
||||
std::vector<float> output(
|
||||
conv2d.outputSize * conv2d.outputSize * numFilters
|
||||
@@ -279,22 +269,18 @@ TEST_F(Conv2dTest, StridedPaddedConvolution) {
|
||||
// Generated by tools/generate_conv2d_test.py
|
||||
std::vector<float> expected = {
|
||||
// Channel 1
|
||||
0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
|
||||
0.0f, 1.59803f, 2.84444f, 1.6201f, 0.0f,
|
||||
0.0f, 2.38937f, 3.80762f, 3.39679f, 0.0f,
|
||||
0.0f, 1.13102f, 2.33335f, 1.98488f, 0.0f,
|
||||
0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
|
||||
0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.59803f, 2.84444f, 1.6201f, 0.0f,
|
||||
0.0f, 2.38937f, 3.80762f, 3.39679f, 0.0f, 0.0f, 1.13102f, 2.33335f,
|
||||
1.98488f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
|
||||
// Channel 2
|
||||
0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
|
||||
0.0f, 2.57732f, 3.55543f, 2.24675f, 0.0f,
|
||||
0.0f, 3.36842f, 3.41373f, 3.14804f, 0.0f,
|
||||
0.0f, 1.17963f, 2.55005f, 1.63218f, 0.0f,
|
||||
0.0f, 0.0f, 0.0f, 0.0f, 0.0f
|
||||
0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 2.57732f, 3.55543f, 2.24675f, 0.0f,
|
||||
0.0f, 3.36842f, 3.41373f, 3.14804f, 0.0f, 0.0f, 1.17963f, 2.55005f,
|
||||
1.63218f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f
|
||||
};
|
||||
|
||||
for (int i = 0; i < output.size(); i++) {
|
||||
EXPECT_NEAR(output[i], expected[i], 0.0001f);
|
||||
}
|
||||
|
||||
commonTestTeardown(d_input, d_output);
|
||||
commonTestTeardown(d_input);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user