From 42d646750b8fd054e3c331d3713a4571b0156886 Mon Sep 17 00:00:00 2001
From: LordMathis
Date: Sun, 17 Mar 2024 18:37:15 +0100
Subject: [PATCH] Abstract activation and implement softmax

---
 CMakeLists.txt                                |  3 +-
 README.md                                     |  5 +-
 include/kernels/activation_functions.cuh      | 74 +++++++++++++++++
 include/kernels/activations.cuh               | 28 -------
 include/kernels/matmul.cuh                    | 13 ---
 include/layers/activation.cuh                 | 55 +++++++++++++
 include/layers/conv2d.cuh                     | 36 +++++----
 include/layers/dense.cuh                      |  4 +-
 include/layers/ilayer.cuh                     | 10 ---
 src/kernels/activation_functions.cu           | 79 ++++++++++++++++++
 src/kernels/activations.cu                    | 29 -------
 src/kernels/matmul.cu                         | 12 ---
 src/layers/activation.cu                      | 60 ++++++++++++++
 src/layers/conv2d.cu                          | 40 ++++-----
 src/layers/dense.cu                           | 25 ++----
 test/CMakeLists.txt                           |  3 +-
 ...ations.cu => test_activation_functions.cu} |  2 +-
 test/layers/test_conv2d.cu                    | 81 ++++++++++---------
 test/layers/test_dense.cu                     | 16 ++--
 19 files changed, 370 insertions(+), 205 deletions(-)
 create mode 100644 include/kernels/activation_functions.cuh
 delete mode 100644 include/kernels/activations.cuh
 create mode 100644 include/layers/activation.cuh
 create mode 100644 src/kernels/activation_functions.cu
 delete mode 100644 src/kernels/activations.cu
 create mode 100644 src/layers/activation.cu
 rename test/kernels/{test_activations.cu => test_activation_functions.cu} (96%)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1d3014d..f5ea2ad 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -9,12 +9,13 @@ include_directories(${CUDAToolkit_INCLUDE_DIRS})
 
 set(LIBRARY_SOURCES
     src/utils/cuda_helper.cu
-    src/kernels/activations.cu
+    src/kernels/activation_functions.cu
     src/kernels/convolution.cu
     src/kernels/matmul.cu
     src/layers/dense.cu
     src/layers/conv2d.cu
     src/layers/input.cu
+    src/layers/activation.cu
 )
 
 set(CMAKE_CUDA_ARCHITECTURES 75)
diff --git a/README.md b/README.md
index beff362..fac3299 100644
--- a/README.md
+++ b/README.md
@@ -24,7 +24,7 @@ Convolutional Neural Network inference library running on CUDA.
 - [CUDA](https://developer.nvidia.com/cuda-downloads)
 - [Google Test](https://github.com/google/googletest) (for testing only)
 
-**build and test**
+**build**
 
 ```sh
 mkdir build
 cd build
 cmake -S ..
 make
 ```
 
-Run tests
+**build and run tests**
 
 ```sh
+make test_main
 ./test/test_main
 ```
\ No newline at end of file
diff --git a/include/kernels/activation_functions.cuh b/include/kernels/activation_functions.cuh
new file mode 100644
index 0000000..ce85e0c
--- /dev/null
+++ b/include/kernels/activation_functions.cuh
@@ -0,0 +1,74 @@
+#ifndef CUDANET_ACTIVATION_FUNCTIONS_H
+#define CUDANET_ACTIVATION_FUNCTIONS_H
+
+namespace CUDANet::Kernels {
+
+/**
+ * @brief Sigmoid activation function kernel
+ *
+ * @param src Pointer to the source array
+ * @param dst Pointer to the destination array
+ * @param len Length of the arrays
+ */
+__global__ void sigmoid(
+    const float* __restrict__ src,
+    float* __restrict__ dst,
+    const unsigned int len
+);
+
+/**
+ * @brief Relu activation function kernel
+ *
+ * @param src Pointer to the source array
+ * @param dst Pointer to the destination array
+ * @param len Length of the arrays
+ */
+__global__ void relu(
+    const float* __restrict__ src,
+    float* __restrict__ dst,
+    const unsigned int len
+);
+
+/**
+ * @brief Softmax activation exponentiation kernel
+ *
+ * @param src Pointer to the source array
+ * @param dst Pointer to the destination array
+ * @param len Length of the arrays
+ */
+__global__ void softmax_exp(
+    const float* __restrict__ src,
+    float* __restrict__ dst,
+    const unsigned int len
+);
+
+/**
+ * @brief Softmax sum reduction kernel
+ *
+ * @param d_vector Device pointer to vector
+ * @param d_output Device pointer to output vector
+ * @param w Length of the vector
+ */
+__global__ void softmax_sum(
+    const float* __restrict__ d_vector,
+    float* __restrict__ d_output,
+    const unsigned int w
+);
+
+/**
+ * @brief Softmax division kernel, divides each element of src by *sum
+ *
+ * @param src Pointer to the source array
+ * @param dst Pointer to the destination array
+ * @param len Length of the arrays
+ */
+__global__ void softmax_div(
+    const float* __restrict__ src,
+    float* __restrict__ dst,
+    const float* __restrict__ sum,
+    const unsigned int len
+);
+
+} // namespace CUDANet::Kernels
+
+#endif // CUDANET_ACTIVATION_FUNCTIONS_H
\ No newline at end of file
diff --git a/include/kernels/activations.cuh b/include/kernels/activations.cuh
deleted file mode 100644
index 5e85edf..0000000
--- a/include/kernels/activations.cuh
+++ /dev/null
@@ -1,28 +0,0 @@
-#ifndef CUDANET_ACTIVATIONS_H
-#define CUDANET_ACTIVATIONS_H
-
-namespace CUDANet::Kernels {
-
-/**
- * @brief Sigmoid activation function kernel
- *
- * @param src Pointer to the source array
- * @param dst Pointer to the destination array
- * @param len Length of the arrays
- */
-__global__ void
-sigmoid(const float* __restrict__ src, float* __restrict__ dst, int len);
-
-/**
- * @brief Relu activation function kernel
- *
- * @param src Pointer to the source array
- * @param dst Pointer to the destination array
- * @param len Length of the arrays
- */
-__global__ void
-relu(const float* __restrict__ src, float* __restrict__ dst, int len);
-
-} // namespace CUDANet::Kernels
-
-#endif // CUDANET_ACTIVATIONS_H
\ No newline at end of file
diff --git a/include/kernels/matmul.cuh b/include/kernels/matmul.cuh
index 09de250..3d067b8 100644
--- a/include/kernels/matmul.cuh
+++ b/include/kernels/matmul.cuh
@@ -35,19 +35,6 @@ __global__ void vec_vec_add(
     const unsigned int w
 );
 
-/**
- * @brief
- *
- * @param d_vector Device pointer to vector
- * @param d_output Device pointer to output vector
- * @param w Length of the vector
- */
-__global__ void reduce_sum(
-    const float* __restrict__ d_vector,
-    float* __restrict__ d_output,
-    const unsigned int w
-);
-
 } // namespace CUDANet::Kernels
 
 #endif // CUDANET_MATMUL_H
\ No newline at end of file
diff --git a/include/layers/activation.cuh b/include/layers/activation.cuh
new file mode 100644
index 0000000..3a407e6
--- /dev/null
+++ b/include/layers/activation.cuh
@@ -0,0 +1,55 @@
+#ifndef CUDANET_ACTIVATION_H
+#define CUDANET_ACTIVATION_H
+
+namespace CUDANet::Layers {
+
+/**
+ * @brief Activation functions
+ *
+ * SIGMOID: Sigmoid
+ * RELU: Rectified Linear Unit
+ * SOFTMAX: Softmax
+ *
+ */
+enum ActivationType { SIGMOID, RELU, SOFTMAX, NONE };
+
+class Activation {
+  public:
+
+    Activation() = default;
+
+    /**
+     * @brief Construct a new Activation object
+     *
+     * @param activation Type of activation
+     * @param length Length of the input
+     */
+    Activation(ActivationType activation, const unsigned int length);
+
+    /**
+     * @brief Destroy the Activation object
+     *
+     */
+    ~Activation();
+
+    /**
+     * @brief Run the activation function on the input
+     *
+     * @param d_input Pointer to the input vector on the device
+     */
+    void activate(float* d_input);
+
+  private:
+    ActivationType activationType;
+    unsigned int   length;
+    unsigned int   gridSize;
+
+    float* d_softmax_sum;
+};
+
+} // namespace CUDANet::Layers
+
+#endif // CUDANET_ACTIVATION_H
\ No newline at end of file
diff --git a/include/layers/conv2d.cuh b/include/layers/conv2d.cuh
index 54b071e..53a06fd 100644
--- a/include/layers/conv2d.cuh
+++ b/include/layers/conv2d.cuh
@@ -4,7 +4,7 @@
 #include
 #include
 
-#include "activations.cuh"
+#include "activation.cuh"
 #include "convolution.cuh"
 #include "ilayer.cuh"
@@ -23,18 +23,18 @@ class Conv2d : public ILayer {
      * @param inputChannels Number of channels in the input matrix
      * @param kernelSize Width and height of the convolution kernel
      * @param stride Convolution stride
-     * @param padding Padding type ('SAME' or 'VALID')
      * @param numFilters Number of output filters
-     * @param activation Activation function ('RELU', 'SIGMOID' or 'NONE')
+     * @param padding Padding type ('SAME' or 'VALID')
+     * @param activationType Activation function type ('RELU', 'SIGMOID', 'SOFTMAX' or 'NONE')
      */
     Conv2d(
-        int                inputSize,
-        int                inputChannels,
-        int                kernelSize,
-        int                stride,
-        Layers::Padding    padding,
-        int                numFilters,
-        Layers::Activation activation
+        int                    inputSize,
+        int                    inputChannels,
+        int                    kernelSize,
+        int                    stride,
+        int                    numFilters,
+        Layers::Padding        padding,
+        Layers::ActivationType activationType
     );
 
     /**
@@ -67,17 +67,21 @@
     /**
      * @brief Get the output width (/ height) of the layer
-     *
-     * @return int
+     *
+     * @return int
      */
-    int getOutputSize() { return outputSize; }
+    int getOutputSize() {
+        return outputSize;
+    }
 
     /**
      * @brief Get the padding size of the layer
-     *
-     * @return int
+     *
+     * @return int
      */
-    int getPaddingSize() { return paddingSize; }
+    int getPaddingSize() {
+        return paddingSize;
+    }
 
   private:
     // Inputs
diff --git a/include/layers/dense.cuh b/include/layers/dense.cuh
index 58d4be5..b8fb207 100644
--- a/include/layers/dense.cuh
+++ b/include/layers/dense.cuh
@@ -20,9 +20,9 @@ class Dense : public ILayer {
      *
      * @param inputSize Size of the input vector
     * @param outputSize Size of the output vector
-     * @param activation Activation function ('RELU', 'SIGMOID' or 'NONE')
+     * @param activationType Activation function type ('RELU', 'SIGMOID', 'SOFTMAX' or 'NONE')
      */
-    Dense(int inputSize, int outputSize, Layers::Activation activation);
+    Dense(int inputSize, int outputSize, Layers::ActivationType activationType);
 
     /**
      * @brief Destroy the Dense layer
      */
diff --git a/include/layers/ilayer.cuh b/include/layers/ilayer.cuh
index d828364..f1ab5fd 100644
--- a/include/layers/ilayer.cuh
+++ b/include/layers/ilayer.cuh
@@ -6,15 +6,6 @@
 
 namespace CUDANet::Layers {
 
-/**
- * @brief Activation functions
- *
- * SIGMOID: Sigmoid
- * RELU: Rectified Linear Unit
- *
- */
-enum Activation { SIGMOID, RELU, NONE };
-
 /**
  * @brief Padding types
  *
@@ -85,7 +76,6 @@ class ILayer {
     std::vector weights;
     std::vector biases;
 
-    Layers::Activation activation;
 };
 
 } // namespace CUDANet::Layers
diff --git a/src/kernels/activation_functions.cu b/src/kernels/activation_functions.cu
new file mode 100644
index 0000000..296c7cb
--- /dev/null
+++ b/src/kernels/activation_functions.cu
@@ -0,0 +1,79 @@
+#include
+
+#include "activation_functions.cuh"
+#include "cuda_helper.cuh"
+
+__global__ void CUDANet::Kernels::sigmoid(
+    const float* __restrict__ src,
+    float* __restrict__ dst,
+    const unsigned int len
+) {
+    int stride = gridDim.x * blockDim.x;
+    int tid    = blockDim.x * blockIdx.x + threadIdx.x;
+
+    for (int i = tid; i < len; i += stride) {
+        dst[i] = 1.0 / (1.0 + exp(-src[i]));
+    }
+}
+
+__global__ void CUDANet::Kernels::relu(
+    const float* __restrict__ src,
+    float* __restrict__ dst,
+    const unsigned int len
+) {
+    int stride = gridDim.x * blockDim.x;
+    int tid    = blockDim.x * blockIdx.x + threadIdx.x;
+
+    for (int i = tid; i < len; i += stride) {
+        dst[i] = src[i] < 0.0 ? 0.0 : src[i];
+    }
+}
+
+__global__ void CUDANet::Kernels::softmax_exp(
+    const float* __restrict__ src,
+    float* __restrict__ dst,
+    const unsigned int len
+) {
+    int stride = gridDim.x * blockDim.x;
+    int tid    = blockDim.x * blockIdx.x + threadIdx.x;
+
+    for (int i = tid; i < len; i += stride) {
+        dst[i] = exp(src[i]);
+    }
+}
+
+__global__ void CUDANet::Kernels::softmax_sum(
+    const float* __restrict__ d_vector,
+    float* __restrict__ d_output,
+    const unsigned int w
+) {
+    __shared__ float partial_sum[BLOCK_SIZE];
+    int i = blockIdx.x * blockDim.x * 2 + threadIdx.x;
+    partial_sum[threadIdx.x] = d_vector[i] + d_vector[i + blockDim.x];
+    __syncthreads();
+
+    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
+        if (threadIdx.x < s) {
+            partial_sum[threadIdx.x] += partial_sum[threadIdx.x + s];
+        }
+        __syncthreads();
+    }
+
+    if (threadIdx.x == 0) {
+        d_output[blockIdx.x] = partial_sum[0];
+    }
+}
+
+__global__ void CUDANet::Kernels::softmax_div(
+    const float* __restrict__ src,
+    float* __restrict__ dst,
+    const float* __restrict__ sum,
+    const unsigned int len
+) {
+    int stride = gridDim.x * blockDim.x;
+    int tid    = blockDim.x * blockIdx.x + threadIdx.x;
+
+    for (int i = tid; i < len; i += stride) {
+        dst[i] = src[i] / sum[0];
+    }
+}
\ No newline at end of file
diff --git a/src/kernels/activations.cu b/src/kernels/activations.cu
deleted file mode 100644
index 82efd47..0000000
--- a/src/kernels/activations.cu
+++ /dev/null
@@ -1,29 +0,0 @@
-#include
-
-#include "activations.cuh"
-
-__global__ void CUDANet::Kernels::sigmoid(
-    const float* __restrict__ src,
-    float* __restrict__ dst,
-    int len
-) {
-    int stride = gridDim.x * blockDim.x;
-    int tid = blockDim.x * blockIdx.x + threadIdx.x;
-
-    for (int i = tid; i < len; i += stride) {
-        dst[i] = 1.0 / (1.0 + exp(-src[i]));
-    }
-}
-
-__global__ void CUDANet::Kernels::relu(
-    const float* __restrict__ src,
-    float* __restrict__ dst,
-    int len
-) {
-    int stride = gridDim.x * blockDim.x;
-    int tid = blockDim.x * blockIdx.x + threadIdx.x;
-
-    for (int i = tid; i < len; i += stride) {
-        dst[i] = src[i] < 0.0 ? 0.0 : src[i];
-    }
-}
diff --git a/src/kernels/matmul.cu b/src/kernels/matmul.cu
index 5c664c4..93b6dfe 100644
--- a/src/kernels/matmul.cu
+++ b/src/kernels/matmul.cu
@@ -47,15 +47,3 @@ __global__ void CUDANet::Kernels::vec_vec_add(
     }
     d_output[tid] = d_vector1[tid] + d_vector2[tid];
 }
-
-__global__ void CUDANet::Kernels::reduce_sum(
-    const float* __restrict__ d_vector,
-    float* __restrict__ d_output,
-    const unsigned int w
-) {
-    int tid = blockDim.x * blockIdx.x + threadIdx.x;
-
-    __shared__ float shared[BLOCK_SIZE];
-    shared[threadIdx.x] = d_vector[tid];
-    __syncthreads();
-}
\ No newline at end of file
diff --git a/src/layers/activation.cu b/src/layers/activation.cu
new file mode 100644
index 0000000..94e9cce
--- /dev/null
+++ b/src/layers/activation.cu
@@ -0,0 +1,60 @@
+#include "activation.cuh"
+
+#include "cuda_helper.cuh"
+#include "activation_functions.cuh"
+
+using namespace CUDANet;
+
+Layers::Activation::Activation(ActivationType activation, const unsigned int length)
+    : activationType(activation), length(length) {
+
+    if (activationType == SOFTMAX) {
+        d_softmax_sum = nullptr;
+        CUDA_CHECK(cudaMalloc((void**)&d_softmax_sum, sizeof(float) * length));
+    }
+
+    gridSize = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
+}
+
+Layers::Activation::~Activation() {
+    if (activationType == SOFTMAX) {
+        cudaFree(d_softmax_sum);
+    }
+}
+
+void Layers::Activation::activate(float* __restrict__ d_input) {
+
+    switch (activationType) {
+        case SIGMOID:
+            Kernels::sigmoid<<<gridSize, BLOCK_SIZE>>>(
+                d_input, d_input, length
+            );
+            break;
+
+        case RELU:
+            Kernels::relu<<<gridSize, BLOCK_SIZE>>>(
+                d_input, d_input, length
+            );
+            break;
+        case SOFTMAX:
+            Kernels::softmax_exp<<<gridSize, BLOCK_SIZE>>>(
+                d_input, d_input, length
+            );
+
+            Kernels::softmax_sum<<<gridSize, BLOCK_SIZE>>>(
+                d_input, d_softmax_sum, length
+            );
+
+            Kernels::softmax_sum<<<1, BLOCK_SIZE>>>(
+                d_softmax_sum, d_softmax_sum, length
+            );
+
+            Kernels::softmax_div<<<gridSize, BLOCK_SIZE>>>(
+                d_input, d_input, d_softmax_sum, length
+            );
+            break;
+
+        default:
+            break;
+    }
+}
\ No newline at end of file
diff --git a/src/layers/conv2d.cu b/src/layers/conv2d.cu
index ca36956..2c4d7c7 100644
--- a/src/layers/conv2d.cu
+++ b/src/layers/conv2d.cu
@@ -1,7 +1,7 @@
 #include
 #include
 
-#include "activations.cuh"
+#include "activation.cuh"
 #include "conv2d.cuh"
 #include "convolution.cuh"
 #include "cuda_helper.cuh"
@@ -10,20 +10,19 @@ using namespace CUDANet;
 
 Layers::Conv2d::Conv2d(
-    int                inputSize,
-    int                inputChannels,
-    int                kernelSize,
-    int                stride,
-    Layers::Padding    padding,
-    int                numFilters,
-    Layers::Activation activation
+    int                    inputSize,
+    int                    inputChannels,
+    int                    kernelSize,
+    int                    stride,
+    int                    numFilters,
+    Layers::Padding        padding,
+    Layers::ActivationType activationType
 )
     : inputSize(inputSize),
      inputChannels(inputChannels),
      kernelSize(kernelSize),
      stride(stride),
-      numFilters(numFilters),
-      activation(activation) {
+      numFilters(numFilters) {
     switch (padding) {
         case SAME:
             outputSize = inputSize;
             break;
@@ -39,10 +38,13 @@ Layers::Conv2d::Conv2d(
             break;
     }
 
+    activation = Layers::Activation(
+        activationType, outputSize * outputSize * numFilters
+    );
+
     d_output = nullptr;
     CUDA_CHECK(cudaMalloc(
-        (void**)&d_output,
-        sizeof(float) * outputSize * outputSize * numFilters
+        (void**)&d_output, sizeof(float) * outputSize * outputSize * numFilters
     ));
 
     weights.resize(kernelSize * kernelSize * inputChannels * numFilters);
@@ -131,18 +133,8 @@ float* Layers::Conv2d::forward(const float* d_input) {
         d_biases, d_output, d_output, biases.size()
     );
 
-    switch (activation) {
-        case SIGMOID:
-            Kernels::sigmoid<<<1, outputSize>>>(d_output, d_output, outputSize);
-            break;
-        case RELU:
-            Kernels::relu<<<1, outputSize>>>(d_output, d_output, outputSize);
-            break;
-
-        default:
-            break;
-    }
+    // Apply activation
+    activation.activate(d_output);
 
     CUDA_CHECK(cudaDeviceSynchronize());
 
diff --git a/src/layers/dense.cu b/src/layers/dense.cu
index bd700f6..11896da 100644
--- a/src/layers/dense.cu
+++ b/src/layers/dense.cu
@@ -5,7 +5,7 @@
 #include
 #include
 
-#include "activations.cuh"
+#include "activation.cuh"
 #include "cuda_helper.cuh"
 #include "dense.cuh"
 #include "matmul.cuh"
@@ -15,13 +15,15 @@ using namespace CUDANet;
 
 Layers::Dense::Dense(
     int inputSize,
     int outputSize,
-    Layers::Activation activation
+    Layers::ActivationType activationType
 )
-    : inputSize(inputSize), outputSize(outputSize), activation(activation) {
+    : inputSize(inputSize), outputSize(outputSize) {
     // Allocate memory for weights and biases
     weights.resize(outputSize * inputSize);
     biases.resize(outputSize);
 
+    activation = Layers::Activation(activationType, outputSize);
+
     initializeWeights();
     initializeBiases();
 
@@ -69,22 +71,7 @@ float* Layers::Dense::forward(const float* d_input) {
         d_biases, d_output, d_output, outputSize
     );
 
-    switch (activation) {
-        case SIGMOID:
-            Kernels::sigmoid<<>>(
-                d_output, d_output, outputSize
-            );
-            break;
-
-        case RELU:
-            Kernels::relu<<>>(
-                d_output, d_output, outputSize
-            );
-            break;
-
-        default:
-            break;
-    }
+    activation.activate(d_output);
 
     CUDA_CHECK(cudaDeviceSynchronize());
 
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 3bf8597..af1ab20 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -2,10 +2,11 @@ find_package(GTest REQUIRED)
 include_directories(${GTEST_INCLUDE_DIRS})
 
 add_executable(test_main
+    EXCLUDE_FROM_ALL
     layers/test_dense.cu
     layers/test_conv2d.cu
     layers/test_input.cu
-    kernels/test_activations.cu
+    kernels/test_activation_functions.cu
     kernels/test_padding.cu
     kernels/test_matmul.cu
 )
diff --git a/test/kernels/test_activations.cu b/test/kernels/test_activation_functions.cu
similarity index 96%
rename from test/kernels/test_activations.cu
rename to test/kernels/test_activation_functions.cu
index adfde76..e62d188 100644
--- a/test/kernels/test_activations.cu
+++ b/test/kernels/test_activation_functions.cu
@@ -3,7 +3,7 @@
 
 #include
 
-#include "activations.cuh"
+#include "activation_functions.cuh"
 
 TEST(ActivationsTest, SigmoidSanityCheck) {
 
diff --git a/test/layers/test_conv2d.cu b/test/layers/test_conv2d.cu
index 24ffd27..366d920 100644
--- a/test/layers/test_conv2d.cu
+++ b/test/layers/test_conv2d.cu
@@ -8,21 +8,21 @@ class Conv2dTest : public ::testing::Test {
   protected:
     CUDANet::Layers::Conv2d commonTestSetup(
-        int inputSize,
-        int inputChannels,
-        int kernelSize,
-        int stride,
-        CUDANet::Layers::Padding padding,
-        int numFilters,
-        CUDANet::Layers::Activation activation,
-        std::vector& input,
-        float* kernels,
-        float*& d_input
+        int                             inputSize,
+        int                             inputChannels,
+        int                             kernelSize,
+        int                             stride,
+        int                             numFilters,
+        CUDANet::Layers::Padding        padding,
+        CUDANet::Layers::ActivationType activationType,
+        std::vector&                    input,
+        float*                          kernels,
+        float*&                         d_input
     ) {
         // Create Conv2d layer
         CUDANet::Layers::Conv2d conv2d(
-            inputSize, inputChannels, kernelSize, stride, padding, numFilters,
-            activation
+            inputSize, inputChannels, kernelSize, stride, numFilters, padding,
+            activationType
         );
 
         conv2d.setWeights(kernels);
@@ -53,13 +53,14 @@
 };
 
 TEST_F(Conv2dTest, SimpleTest) {
-    int inputSize = 4;
-    int inputChannels = 1;
-    int kernelSize = 2;
-    int stride = 1;
-    CUDANet::Layers::Padding padding = CUDANet::Layers::Padding::VALID;
-    int numFilters = 1;
-    CUDANet::Layers::Activation activation = CUDANet::Layers::Activation::NONE;
+    int inputSize      = 4;
+    int inputChannels  = 1;
+    int kernelSize     = 2;
+    int stride         = 1;
+    int numFilters     = 1;
+    CUDANet::Layers::Padding padding = CUDANet::Layers::Padding::VALID;
+    CUDANet::Layers::ActivationType activationType =
+        CUDANet::Layers::ActivationType::NONE;
 
     std::vector input = {1.0f, 2.0f,  3.0f,  4.0f,
                          5.0f, 6.0f,  7.0f,  8.0f,
                          9.0f, 10.0f, 11.0f, 12.0f,
@@ -75,8 +76,8 @@
     float* d_input;
     float* d_output;
 
     CUDANet::Layers::Conv2d conv2d = commonTestSetup(
-        inputSize, inputChannels, kernelSize, stride, padding, numFilters,
-        activation, input, kernels.data(), d_input
+        inputSize, inputChannels, kernelSize, stride, numFilters, padding,
+        activationType, input, kernels.data(), d_input
     );
 
     int outputSize = (inputSize - kernelSize) / stride + 1;
@@ -102,13 +103,14 @@
 }
 
 TEST_F(Conv2dTest, PaddedTest) {
-    int inputSize = 5;
-    int inputChannels = 3;
-    int kernelSize = 3;
-    int stride = 1;
-    CUDANet::Layers::Padding padding = CUDANet::Layers::Padding::SAME;
-    int numFilters = 2;
-    CUDANet::Layers::Activation activation = CUDANet::Layers::Activation::NONE;
+    int inputSize      = 5;
+    int inputChannels  = 3;
+    int kernelSize     = 3;
+    int stride         = 1;
+    int numFilters     = 2;
+    CUDANet::Layers::Padding padding = CUDANet::Layers::Padding::SAME;
+    CUDANet::Layers::ActivationType activationType =
+        CUDANet::Layers::ActivationType::NONE;
 
     // clang-format off
     std::vector input = {
@@ -164,8 +166,8 @@
     float* d_input;
     float* d_output;
 
     CUDANet::Layers::Conv2d conv2d = commonTestSetup(
-        inputSize, inputChannels, kernelSize, stride, padding, numFilters,
-        activation, input, kernels.data(), d_input
+        inputSize, inputChannels, kernelSize, stride, numFilters, padding,
+        activationType, input, kernels.data(), d_input
     );
 
     EXPECT_EQ(inputSize, conv2d.getOutputSize());
@@ -203,13 +205,14 @@
 }
 
 TEST_F(Conv2dTest, StridedPaddedConvolution) {
-    int inputSize = 5;
-    int inputChannels = 2;
-    int kernelSize = 3;
-    int stride = 2;
-    int numFilters = 2;
-    CUDANet::Layers::Padding padding = CUDANet::Layers::Padding::SAME;
-    CUDANet::Layers::Activation activation = CUDANet::Layers::Activation::RELU;
+    int inputSize      = 5;
+    int inputChannels  = 2;
+    int kernelSize     = 3;
+    int stride         = 2;
+    int numFilters     = 2;
+    CUDANet::Layers::Padding padding = CUDANet::Layers::Padding::SAME;
+    CUDANet::Layers::ActivationType activationType =
+        CUDANet::Layers::ActivationType::RELU;
 
     // clang-format off
     std::vector input = {
@@ -250,8 +253,8 @@
     float* d_input;
     float* d_output;
 
     CUDANet::Layers::Conv2d conv2d = commonTestSetup(
-        inputSize, inputChannels, kernelSize, stride, padding, numFilters,
-        activation, input, kernels.data(), d_input
+        inputSize, inputChannels, kernelSize, stride, numFilters, padding,
+        activationType, input, kernels.data(), d_input
     );
 
     EXPECT_EQ(inputSize, conv2d.getOutputSize());
diff --git a/test/layers/test_dense.cu b/test/layers/test_dense.cu
index ae44aa0..c53af44 100644
--- a/test/layers/test_dense.cu
+++ b/test/layers/test_dense.cu
@@ -3,7 +3,7 @@
 
 #include
 
-#include "activations.cuh"
+#include "activation.cuh"
 #include "dense.cuh"
 
 class DenseLayerTest : public ::testing::Test {
@@ -15,10 +15,10 @@
         float* weights,
         float* biases,
         float*& d_input,
-        CUDANet::Layers::Activation activation
+        CUDANet::Layers::ActivationType activationType
     ) {
         // Create Dense layer
-        CUDANet::Layers::Dense denseLayer(inputSize, outputSize, activation);
+        CUDANet::Layers::Dense denseLayer(inputSize, outputSize, activationType);
 
         // Set weights and biases
         denseLayer.setWeights(weights);
@@ -53,7 +53,7 @@ TEST_F(DenseLayerTest, Init) {
             int outputSize = j;
 
             CUDANet::Layers::Dense denseLayer(
-                inputSize, outputSize, CUDANet::Layers::Activation::SIGMOID
+                inputSize, outputSize, CUDANet::Layers::ActivationType::SIGMOID
             );
         }
     }
@@ -74,7 +74,7 @@
     // clang-format on
 
     CUDANet::Layers::Dense denseLayer(
-        inputSize, outputSize, CUDANet::Layers::Activation::SIGMOID
+        inputSize, outputSize, CUDANet::Layers::ActivationType::SIGMOID
     );
 
     denseLayer.setWeights(weights.data());
@@ -101,7 +101,7 @@
 
     CUDANet::Layers::Dense denseLayer = commonTestSetup(
         inputSize, outputSize, input, weights.data(), biases.data(), d_input,
-        CUDANet::Layers::Activation::NONE
+        CUDANet::Layers::ActivationType::NONE
     );
 
     d_output = denseLayer.forward(d_input);
@@ -142,7 +142,7 @@
 
     CUDANet::Layers::Dense denseLayer = commonTestSetup(
         inputSize, outputSize, input, weights.data(), biases.data(), d_input,
-        CUDANet::Layers::Activation::RELU
+        CUDANet::Layers::ActivationType::RELU
     );
 
     d_output = denseLayer.forward(d_input);
@@ -187,7 +187,7 @@
 
     CUDANet::Layers::Dense denseLayer = commonTestSetup(
         inputSize, outputSize, input, weights.data(), biases.data(), d_input,
-        CUDANet::Layers::Activation::SIGMOID
+        CUDANet::Layers::ActivationType::SIGMOID
    );
 
     d_output = denseLayer.forward(d_input);
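
For reference, a minimal usage sketch of the `CUDANet::Layers::Activation` class introduced by this patch; it is not part of the patch itself. It assumes the library built from these sources is linked, that the program is compiled with nvcc, and it uses an arbitrary 5-element SIGMOID example (SOFTMAX is constructed the same way):

```cpp
// example_main.cu (hypothetical, for illustration only)
#include <cuda_runtime.h>

#include <iostream>
#include <vector>

#include "activation.cuh"

int main() {
    const unsigned int length = 5;
    std::vector<float> data = {-2.0f, -1.0f, 0.0f, 1.0f, 2.0f};

    // Construct the activation layer declared in include/layers/activation.cuh
    CUDANet::Layers::Activation sigmoid(
        CUDANet::Layers::ActivationType::SIGMOID, length
    );

    // Copy the input to the device
    float* d_data = nullptr;
    cudaMalloc((void**)&d_data, sizeof(float) * length);
    cudaMemcpy(d_data, data.data(), sizeof(float) * length, cudaMemcpyHostToDevice);

    // Runs the corresponding kernel from activation_functions.cu in place
    sigmoid.activate(d_data);
    cudaDeviceSynchronize();

    // Copy back and print the activated values (each squashed into (0, 1))
    cudaMemcpy(data.data(), d_data, sizeof(float) * length, cudaMemcpyDeviceToHost);
    cudaFree(d_data);

    for (float v : data) {
        std::cout << v << " ";
    }
    std::cout << std::endl;
    return 0;
}
```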