diff --git a/include/layers/conv2d.cuh b/include/layers/conv2d.cuh index c3601ba..e30321e 100644 --- a/include/layers/conv2d.cuh +++ b/include/layers/conv2d.cuh @@ -11,16 +11,19 @@ namespace Layers { class Conv2d { public: Conv2d( - int inputSize, - int inputChannels, - int kernelSize, - int stride, - std::string padding, - int numFilters, - Activation activation + int inputSize, + int inputChannels, + int kernelSize, + int stride, + std::string padding, + int numFilters, + Activation activation ); ~Conv2d(); + // Outputs + int outputSize; + void forward(const float* d_input, float* d_output); private: @@ -34,15 +37,12 @@ class Conv2d { int paddingSize; int numFilters; - // Outputs - int outputSize; - // Kernels std::vector kernels; // Cuda - float* d_kernels; - float* d_padded; + float* d_kernels; + float* d_padded; // Kernels Activation activation; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 439e74a..f0ffefd 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -3,6 +3,7 @@ include_directories(${GTEST_INCLUDE_DIRS}) add_executable(test_main layers/test_dense.cu + layers/test_conv2d.cu kernels/test_activations.cu kernels/test_padding.cu ) diff --git a/test/layers/test_conv2d.cu b/test/layers/test_conv2d.cu new file mode 100644 index 0000000..9e620d2 --- /dev/null +++ b/test/layers/test_conv2d.cu @@ -0,0 +1,35 @@ +#include +#include + +#include + +#include "conv2d.cuh" + +TEST(Conv2dTest, ValidPadding) { + + int inputSize = 3; + int inputChannels = 1; + int kernelSize = 3; + int stride = 1; + std::string padding = "VALID"; + int numFilters = 1; + Activation activation = LINEAR; + + Layers::Conv2d conv2d( + inputSize, + inputChannels, + kernelSize, + stride, + padding, + numFilters, + activation + ); + + int outputSize = (inputSize - kernelSize) / stride + 1; + EXPECT_EQ(outputSize, conv2d.outputSize); + + std::vector input(inputSize * inputSize * inputChannels); + std::vector output(outputSize * outputSize * numFilters); + std::vector kernels(kernelSize * kernelSize * inputChannels * numFilters); + +}