From 0c22fac64e1df83b62f6b6368e2bbffe0bc99674 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Sun, 17 Mar 2024 16:08:53 +0100 Subject: [PATCH] Add toplevel CUDANet namespace --- include/kernels/activations.cuh | 4 +- include/kernels/convolution.cuh | 38 ++++++++-------- include/kernels/matmul.cuh | 39 ++++++++++------ include/layers/conv2d.cuh | 4 +- include/layers/dense.cuh | 12 ++--- include/layers/ilayer.cuh | 4 +- include/layers/input.cuh | 4 +- src/kernels/activations.cu | 9 ++-- src/kernels/convolution.cu | 37 +++++++-------- src/kernels/matmul.cu | 41 ++++++++++------- src/layers/conv2d.cu | 2 + src/layers/dense.cu | 2 + src/layers/input.cu | 2 + test/kernels/test_activations.cu | 2 +- test/kernels/test_matmul.cu | 2 +- test/kernels/test_padding.cu | 2 +- test/layers/test_conv2d.cu | 78 ++++++++++++++++---------------- test/layers/test_dense.cu | 38 ++++++++-------- test/layers/test_input.cu | 12 ++--- 19 files changed, 183 insertions(+), 149 deletions(-) diff --git a/include/kernels/activations.cuh b/include/kernels/activations.cuh index f5c6a8e..5e85edf 100644 --- a/include/kernels/activations.cuh +++ b/include/kernels/activations.cuh @@ -1,7 +1,7 @@ #ifndef CUDANET_ACTIVATIONS_H #define CUDANET_ACTIVATIONS_H -namespace Kernels { +namespace CUDANet::Kernels { /** * @brief Sigmoid activation function kernel @@ -23,6 +23,6 @@ sigmoid(const float* __restrict__ src, float* __restrict__ dst, int len); __global__ void relu(const float* __restrict__ src, float* __restrict__ dst, int len); -} // namespace Kernels +} // namespace CUDANet::Kernels #endif // CUDANET_ACTIVATIONS_H \ No newline at end of file diff --git a/include/kernels/convolution.cuh b/include/kernels/convolution.cuh index ffa8227..59a3b0e 100644 --- a/include/kernels/convolution.cuh +++ b/include/kernels/convolution.cuh @@ -1,11 +1,11 @@ #ifndef CUDANET_CONVOLUTION_H #define CUDANET_CONVOLUTION_H -namespace Kernels { +namespace CUDANet::Kernels { /** * @brief Kernel that pads the input matrix with zeros - * + * * @param d_input Device pointer to the input matrix (as vector) * @param d_padded Device pointer to the padded matrix (as vector) * @param w Width of the input matrix @@ -14,17 +14,17 @@ namespace Kernels { * @param p Padding size */ __global__ void padding( - const float* d_input, - float* d_padded, - int w, - int h, - int n, - int p + const float* __restrict__ d_input, + float* __restrict__ d_padded, + const unsigned int w, + const unsigned int h, + const unsigned int n, + const unsigned int p ); /** * @brief Convolution kernel - * + * * @param d_input Device pointer to the input matrix * @param d_kernel Device pointer to the convolution kernel * @param d_output Device pointer to the output matrix @@ -36,17 +36,17 @@ __global__ void padding( * @param outputSize Width and height of the output matrix */ __global__ void convolution( - const float* d_input, - const float* d_kernel, - float* d_output, - int inputSize, - int nChannels, - int kernelSize, - int stride, - int nFilters, - int outputSize + const float* __restrict__ d_input, + const float* __restrict__ d_kernel, + float* __restrict__ d_output, + const unsigned int inputSize, + const unsigned int nChannels, + const unsigned int kernelSize, + const unsigned int stride, + const unsigned int nFilters, + const unsigned int outputSize ); -} // namespace Kernels +} // namespace CUDANet::Kernels #endif // CUDANET_CONVOLUTION_H \ No newline at end of file diff --git a/include/kernels/matmul.cuh b/include/kernels/matmul.cuh index 43a7f5e..09de250 100644 --- 
a/include/kernels/matmul.cuh +++ b/include/kernels/matmul.cuh @@ -1,11 +1,11 @@ #ifndef CUDANET_MATMUL_H #define CUDANET_MATMUL_H -namespace Kernels { +namespace CUDANet::Kernels { /** * @brief Matrix vector multiplication kernel - * + * * @param d_matrix Device pointer to matrix * @param d_vector Device pointer to vector * @param d_output Device pointer to output vector @@ -13,28 +13,41 @@ namespace Kernels { * @param h Height of the matrix */ __global__ void mat_vec_mul( - const float* d_matrix, - const float* d_vector, - float* d_output, - int w, - int h + const float* __restrict__ d_matrix, + const float* __restrict__ d_vector, + float* __restrict__ d_output, + const unsigned int w, + const unsigned int h ); /** * @brief Vector vector addition kernel - * + * * @param d_vector1 Device pointer to first vector * @param d_vector2 Device pointer to second vector * @param d_output Device pointer to output vector * @param w Length of the vectors */ __global__ void vec_vec_add( - const float* d_vector1, - const float* d_vector2, - float* d_output, - int w + const float* __restrict__ d_vector1, + const float* __restrict__ d_vector2, + float* __restrict__ d_output, + const unsigned int w ); -} // namespace Kernels +/** + * @brief + * + * @param d_vector Device pointer to vector + * @param d_output Device pointer to output vector + * @param w Length of the vector + */ +__global__ void reduce_sum( + const float* __restrict__ d_vector, + float* __restrict__ d_output, + const unsigned int w +); + +} // namespace CUDANet::Kernels #endif // CUDANET_MATMUL_H \ No newline at end of file diff --git a/include/layers/conv2d.cuh b/include/layers/conv2d.cuh index 92d6e05..54b071e 100644 --- a/include/layers/conv2d.cuh +++ b/include/layers/conv2d.cuh @@ -8,7 +8,7 @@ #include "convolution.cuh" #include "ilayer.cuh" -namespace Layers { +namespace CUDANet::Layers { /** * @brief 2D convolutional layer @@ -125,6 +125,6 @@ class Conv2d : public ILayer { void toCuda(); }; -} // namespace Layers +} // namespace CUDANet::Layers #endif // CUDANET_CONV_LAYER_H diff --git a/include/layers/dense.cuh b/include/layers/dense.cuh index 93ce6ab..58d4be5 100644 --- a/include/layers/dense.cuh +++ b/include/layers/dense.cuh @@ -7,7 +7,7 @@ #include "ilayer.cuh" -namespace Layers { +namespace CUDANet::Layers { /** * @brief Dense (fully connected) layer @@ -53,8 +53,8 @@ class Dense : public ILayer { void setBiases(const float* biases); private: - int inputSize; - int outputSize; + unsigned int inputSize; + unsigned int outputSize; float* d_output; @@ -67,8 +67,8 @@ class Dense : public ILayer { Layers::Activation activation; // Precompute kernel launch parameters - int forwardGridSize; - int biasGridSize; + unsigned int forwardGridSize; + unsigned int biasGridSize; /** * @brief Initialize the weights to zeros @@ -89,6 +89,6 @@ class Dense : public ILayer { void toCuda(); }; -} // namespace Layers +} // namespace CUDANet::Layers #endif // CUDANET_DENSE_LAYER_H diff --git a/include/layers/ilayer.cuh b/include/layers/ilayer.cuh index edd8ebe..d828364 100644 --- a/include/layers/ilayer.cuh +++ b/include/layers/ilayer.cuh @@ -4,7 +4,7 @@ #include -namespace Layers { +namespace CUDANet::Layers { /** * @brief Activation functions @@ -88,6 +88,6 @@ class ILayer { Layers::Activation activation; }; -} // namespace Layers +} // namespace CUDANet::Layers #endif // CUDANET_I_LAYERH \ No newline at end of file diff --git a/include/layers/input.cuh b/include/layers/input.cuh index bfa36c6..ee6ecd7 100644 --- a/include/layers/input.cuh +++ 
b/include/layers/input.cuh @@ -3,7 +3,7 @@ #include -namespace Layers { +namespace CUDANet::Layers { /** * @brief Input layer, just copies the input to the device @@ -45,6 +45,6 @@ class Input : public ILayer { float* d_output; }; -} // namespace Layers +} // namespace CUDANet::Layers #endif // CUDANET_INPUT_LAYER_H \ No newline at end of file diff --git a/src/kernels/activations.cu b/src/kernels/activations.cu index 7fbbf74..82efd47 100644 --- a/src/kernels/activations.cu +++ b/src/kernels/activations.cu @@ -2,7 +2,7 @@ #include "activations.cuh" -__global__ void Kernels::sigmoid( +__global__ void CUDANet::Kernels::sigmoid( const float* __restrict__ src, float* __restrict__ dst, int len @@ -15,8 +15,11 @@ __global__ void Kernels::sigmoid( } } -__global__ void -Kernels::relu(const float* __restrict__ src, float* __restrict__ dst, int len) { +__global__ void CUDANet::Kernels::relu( + const float* __restrict__ src, + float* __restrict__ dst, + int len +) { int stride = gridDim.x * blockDim.x; int tid = blockDim.x * blockIdx.x + threadIdx.x; diff --git a/src/kernels/convolution.cu b/src/kernels/convolution.cu index 27daefc..1a5eaf4 100644 --- a/src/kernels/convolution.cu +++ b/src/kernels/convolution.cu @@ -1,6 +1,7 @@ -#include "convolution.cuh" #include +#include "convolution.cuh" + /* Pads matrix width x height x n_channels to width + 2 * padding x height + 2 * padding x n_channels Matrix is represented as a pointer to a vector @@ -47,13 +48,13 @@ pre-allocated) n: Number of channels in input matrix p: Padding */ -__global__ void Kernels::padding( - const float* d_input, - float* d_padded, - int w, - int h, - int n, - int p +__global__ void CUDANet::Kernels::padding( + const float* __restrict__ d_input, + float* __restrict__ d_padded, + const unsigned int w, + const unsigned int h, + const unsigned int n, + const unsigned int p ) { int tid = blockDim.x * blockIdx.x + threadIdx.x; @@ -78,16 +79,16 @@ __global__ void Kernels::padding( } } -__global__ void Kernels::convolution( - const float* d_input, - const float* d_kernel, - float* d_output, - int inputSize, - int nChannels, - int kernelSize, - int stride, - int nFilters, - int outputSize +__global__ void CUDANet::Kernels::convolution( + const float* __restrict__ d_input, + const float* __restrict__ d_kernel, + float* __restrict__ d_output, + const unsigned int inputSize, + const unsigned int nChannels, + const unsigned int kernelSize, + const unsigned int stride, + const unsigned int nFilters, + const unsigned int outputSize ) { int tid = blockDim.x * blockIdx.x + threadIdx.x; diff --git a/src/kernels/matmul.cu b/src/kernels/matmul.cu index a176697..5c664c4 100644 --- a/src/kernels/matmul.cu +++ b/src/kernels/matmul.cu @@ -1,14 +1,12 @@ #include "cuda_helper.cuh" #include "matmul.cuh" -#define SHARED_SIZE 128 * 4 - -__global__ void Kernels::mat_vec_mul( +__global__ void CUDANet::Kernels::mat_vec_mul( const float* __restrict__ d_matrix, const float* __restrict__ d_vector, float* __restrict__ d_output, - int w, - int h + const unsigned int w, + const unsigned int h ) { int tid = blockDim.x * blockIdx.x + threadIdx.x; @@ -16,9 +14,8 @@ __global__ void Kernels::mat_vec_mul( float temp = 0.0f; - #pragma unroll - for (unsigned int i = 0; i < (w + BLOCK_SIZE - 1) / BLOCK_SIZE; i++) - { +#pragma unroll + for (unsigned int i = 0; i < (w + BLOCK_SIZE - 1) / BLOCK_SIZE; i++) { if (i * BLOCK_SIZE + threadIdx.x < w) { shared[threadIdx.x] = d_vector[i * BLOCK_SIZE + threadIdx.x]; } else { @@ -27,22 +24,22 @@ __global__ void Kernels::mat_vec_mul( 
shared[threadIdx.x] = 0.0f; } __syncthreads(); - for (unsigned int j = 0; j < BLOCK_SIZE; j++) - { +#pragma unroll + for (unsigned int j = 0; j < BLOCK_SIZE; j++) { temp += d_matrix[tid * w + i * BLOCK_SIZE + j] * shared[j]; } __syncthreads(); } - + d_output[tid] = temp; } -__global__ void Kernels::vec_vec_add( - const float* d_vector1, - const float* d_vector2, - float* d_output, - int w +__global__ void CUDANet::Kernels::vec_vec_add( + const float* __restrict__ d_vector1, + const float* __restrict__ d_vector2, + float* __restrict__ d_output, + const unsigned int w ) { int tid = blockDim.x * blockIdx.x + threadIdx.x; if (tid >= w) { @@ -50,3 +47,15 @@ __global__ void Kernels::vec_vec_add( } d_output[tid] = d_vector1[tid] + d_vector2[tid]; } + +__global__ void CUDANet::Kernels::reduce_sum( + const float* __restrict__ d_vector, + float* __restrict__ d_output, + const unsigned int w +) { + int tid = blockDim.x * blockIdx.x + threadIdx.x; + + __shared__ float shared[BLOCK_SIZE]; + shared[threadIdx.x] = d_vector[tid]; + __syncthreads(); +} \ No newline at end of file diff --git a/src/layers/conv2d.cu b/src/layers/conv2d.cu index 1c381d9..ca36956 100644 --- a/src/layers/conv2d.cu +++ b/src/layers/conv2d.cu @@ -7,6 +7,8 @@ #include "cuda_helper.cuh" #include "matmul.cuh" +using namespace CUDANet; + Layers::Conv2d::Conv2d( int inputSize, int inputChannels, diff --git a/src/layers/dense.cu b/src/layers/dense.cu index 289c8a0..bd700f6 100644 --- a/src/layers/dense.cu +++ b/src/layers/dense.cu @@ -10,6 +10,8 @@ #include "dense.cuh" #include "matmul.cuh" +using namespace CUDANet; + Layers::Dense::Dense( int inputSize, int outputSize, diff --git a/src/layers/input.cu b/src/layers/input.cu index 30b98fb..d50dde3 100644 --- a/src/layers/input.cu +++ b/src/layers/input.cu @@ -1,6 +1,8 @@ #include "cuda_helper.cuh" #include "input.cuh" +using namespace CUDANet; + Layers::Input::Input(int inputSize) : inputSize(inputSize) { d_output = nullptr; CUDA_CHECK(cudaMalloc((void**)&d_output, sizeof(float) * inputSize)); diff --git a/test/kernels/test_activations.cu b/test/kernels/test_activations.cu index d1d44ae..adfde76 100644 --- a/test/kernels/test_activations.cu +++ b/test/kernels/test_activations.cu @@ -25,7 +25,7 @@ TEST(ActivationsTest, SigmoidSanityCheck) { cudaStatus = cudaMemcpy(d_input, input, sizeof(float) * 3, cudaMemcpyHostToDevice); EXPECT_EQ(cudaStatus, cudaSuccess); - Kernels::sigmoid<<<1, 3>>>(d_input, d_output, 3); + CUDANet::Kernels::sigmoid<<<1, 3>>>(d_input, d_output, 3); cudaStatus = cudaDeviceSynchronize(); EXPECT_EQ(cudaStatus, cudaSuccess); diff --git a/test/kernels/test_matmul.cu b/test/kernels/test_matmul.cu index ecde00e..e1c89af 100644 --- a/test/kernels/test_matmul.cu +++ b/test/kernels/test_matmul.cu @@ -44,7 +44,7 @@ TEST(MatMulTest, MatVecMulTest) { int THREADS_PER_BLOCK = std::max(w, h); int BLOCKS = 1; - Kernels::mat_vec_mul<<<BLOCKS, THREADS_PER_BLOCK>>>(d_matrix, d_vector, d_output, w, h); + CUDANet::Kernels::mat_vec_mul<<<BLOCKS, THREADS_PER_BLOCK>>>(d_matrix, d_vector, d_output, w, h); cudaStatus = cudaDeviceSynchronize(); EXPECT_EQ(cudaStatus, cudaSuccess); diff --git a/test/kernels/test_padding.cu b/test/kernels/test_padding.cu index aadec4c..7e04cb0 100644 --- a/test/kernels/test_padding.cu +++ b/test/kernels/test_padding.cu @@ -51,7 +51,7 @@ TEST(PaddingTest, SimplePaddingTest) { int THREADS_PER_BLOCK = 64; int BLOCKS = paddedSize / THREADS_PER_BLOCK + 1; - Kernels::padding<<<BLOCKS, THREADS_PER_BLOCK>>>( + CUDANet::Kernels::padding<<<BLOCKS, THREADS_PER_BLOCK>>>( d_input, d_padded, w, h, n, p ); cudaStatus = cudaDeviceSynchronize(); diff --git a/test/layers/test_conv2d.cu b/test/layers/test_conv2d.cu
index 993a91b..24ffd27 100644 --- a/test/layers/test_conv2d.cu +++ b/test/layers/test_conv2d.cu @@ -7,20 +7,20 @@ class Conv2dTest : public ::testing::Test { protected: - Layers::Conv2d commonTestSetup( - int inputSize, - int inputChannels, - int kernelSize, - int stride, - Layers::Padding padding, - int numFilters, - Layers::Activation activation, - std::vector<float>& input, - float* kernels, - float*& d_input + CUDANet::Layers::Conv2d commonTestSetup( + int inputSize, + int inputChannels, + int kernelSize, + int stride, + CUDANet::Layers::Padding padding, + int numFilters, + CUDANet::Layers::Activation activation, + std::vector<float>& input, + float* kernels, + float*& d_input ) { // Create Conv2d layer - Layers::Conv2d conv2d( + CUDANet::Layers::Conv2d conv2d( inputSize, inputChannels, kernelSize, stride, padding, numFilters, activation ); @@ -53,13 +53,13 @@ class Conv2dTest : public ::testing::Test { }; TEST_F(Conv2dTest, SimpleTest) { - int inputSize = 4; - int inputChannels = 1; - int kernelSize = 2; - int stride = 1; - Layers::Padding padding = Layers::Padding::VALID; - int numFilters = 1; - Layers::Activation activation = Layers::Activation::NONE; + int inputSize = 4; + int inputChannels = 1; + int kernelSize = 2; + int stride = 1; + CUDANet::Layers::Padding padding = CUDANet::Layers::Padding::VALID; + int numFilters = 1; + CUDANet::Layers::Activation activation = CUDANet::Layers::Activation::NONE; std::vector<float> input = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, @@ -74,7 +74,7 @@ TEST_F(Conv2dTest, SimpleTest) { float* d_input; float* d_output; - Layers::Conv2d conv2d = commonTestSetup( + CUDANet::Layers::Conv2d conv2d = commonTestSetup( inputSize, inputChannels, kernelSize, stride, padding, numFilters, activation, input, kernels.data(), d_input ); @@ -102,13 +102,13 @@ TEST_F(Conv2dTest, SimpleTest) { } TEST_F(Conv2dTest, PaddedTest) { - int inputSize = 5; - int inputChannels = 3; - int kernelSize = 3; - int stride = 1; - Layers::Padding padding = Layers::Padding::SAME; - int numFilters = 2; - Layers::Activation activation = Layers::Activation::NONE; + int inputSize = 5; + int inputChannels = 3; + int kernelSize = 3; + int stride = 1; + CUDANet::Layers::Padding padding = CUDANet::Layers::Padding::SAME; + int numFilters = 2; + CUDANet::Layers::Activation activation = CUDANet::Layers::Activation::NONE; // clang-format off std::vector<float> input = { @@ -163,7 +163,7 @@ TEST_F(Conv2dTest, PaddedTest) { float* d_input; float* d_output; - Layers::Conv2d conv2d = commonTestSetup( + CUDANet::Layers::Conv2d conv2d = commonTestSetup( inputSize, inputChannels, kernelSize, stride, padding, numFilters, activation, input, kernels.data(), d_input ); @@ -177,7 +177,8 @@ TEST_F(Conv2dTest, PaddedTest) { ); cudaMemcpy( output.data(), d_output, - sizeof(float) * conv2d.getOutputSize() * conv2d.getOutputSize() * numFilters, + sizeof(float) * conv2d.getOutputSize() * conv2d.getOutputSize() * + numFilters, cudaMemcpyDeviceToHost ); @@ -202,13 +203,13 @@ TEST_F(Conv2dTest, PaddedTest) { } TEST_F(Conv2dTest, StridedPaddedConvolution) { - int inputSize = 5; - int inputChannels = 2; - int kernelSize = 3; - int stride = 2; - int numFilters = 2; - Layers::Padding padding = Layers::Padding::SAME; - Layers::Activation activation = Layers::Activation::RELU; + int inputSize = 5; + int inputChannels = 2; + int kernelSize = 3; + int stride = 2; + int numFilters = 2; + CUDANet::Layers::Padding padding = CUDANet::Layers::Padding::SAME; + CUDANet::Layers::Activation activation =
CUDANet::Layers::Activation::RELU; // clang-format off std::vector<float> input = { @@ -248,7 +249,7 @@ TEST_F(Conv2dTest, StridedPaddedConvolution) { float* d_input; float* d_output; - Layers::Conv2d conv2d = commonTestSetup( + CUDANet::Layers::Conv2d conv2d = commonTestSetup( inputSize, inputChannels, kernelSize, stride, padding, numFilters, activation, input, kernels.data(), d_input ); @@ -262,7 +263,8 @@ TEST_F(Conv2dTest, StridedPaddedConvolution) { ); cudaMemcpy( output.data(), d_output, - sizeof(float) * conv2d.getOutputSize() * conv2d.getOutputSize() * numFilters, + sizeof(float) * conv2d.getOutputSize() * conv2d.getOutputSize() * + numFilters, cudaMemcpyDeviceToHost ); diff --git a/test/layers/test_dense.cu b/test/layers/test_dense.cu index 9c59a33..ae44aa0 100644 --- a/test/layers/test_dense.cu +++ b/test/layers/test_dense.cu @@ -8,17 +8,17 @@ class DenseLayerTest : public ::testing::Test { protected: - Layers::Dense commonTestSetup( - int inputSize, - int outputSize, - std::vector<float>& input, - float* weights, - float* biases, - float*& d_input, - Layers::Activation activation + CUDANet::Layers::Dense commonTestSetup( + int inputSize, + int outputSize, + std::vector<float>& input, + float* weights, + float* biases, + float*& d_input, + CUDANet::Layers::Activation activation ) { // Create Dense layer - Layers::Dense denseLayer(inputSize, outputSize, activation); + CUDANet::Layers::Dense denseLayer(inputSize, outputSize, activation); // Set weights and biases denseLayer.setWeights(weights); @@ -52,8 +52,8 @@ TEST_F(DenseLayerTest, Init) { int inputSize = i; int outputSize = j; - Layers::Dense denseLayer( - inputSize, outputSize, Layers::Activation::SIGMOID + CUDANet::Layers::Dense denseLayer( + inputSize, outputSize, CUDANet::Layers::Activation::SIGMOID ); } } @@ -73,8 +73,8 @@ TEST_F(DenseLayerTest, setWeights) { }; // clang-format on - Layers::Dense denseLayer( - inputSize, outputSize, Layers::Activation::SIGMOID + CUDANet::Layers::Dense denseLayer( + inputSize, outputSize, CUDANet::Layers::Activation::SIGMOID ); denseLayer.setWeights(weights.data()); @@ -99,9 +99,9 @@ TEST_F(DenseLayerTest, ForwardUnitWeightMatrixLinear) { float* d_input; float* d_output; - Layers::Dense denseLayer = commonTestSetup( + CUDANet::Layers::Dense denseLayer = commonTestSetup( inputSize, outputSize, input, weights.data(), biases.data(), d_input, - Layers::Activation::NONE + CUDANet::Layers::Activation::NONE ); d_output = denseLayer.forward(d_input); @@ -140,9 +140,9 @@ TEST_F(DenseLayerTest, ForwardRandomWeightMatrixRelu) { float* d_input; float* d_output; - Layers::Dense denseLayer = commonTestSetup( + CUDANet::Layers::Dense denseLayer = commonTestSetup( inputSize, outputSize, input, weights.data(), biases.data(), d_input, - Layers::Activation::RELU + CUDANet::Layers::Activation::RELU ); d_output = denseLayer.forward(d_input); @@ -185,9 +185,9 @@ TEST_F(DenseLayerTest, ForwardRandomWeightMatrixSigmoid) { float* d_input; float* d_output; - Layers::Dense denseLayer = commonTestSetup( + CUDANet::Layers::Dense denseLayer = commonTestSetup( inputSize, outputSize, input, weights.data(), biases.data(), d_input, - Layers::Activation::SIGMOID + CUDANet::Layers::Activation::SIGMOID ); d_output = denseLayer.forward(d_input); diff --git a/test/layers/test_input.cu b/test/layers/test_input.cu index f88483b..1986d3c 100644 --- a/test/layers/test_input.cu +++ b/test/layers/test_input.cu @@ -1,16 +1,16 @@ #include <gtest/gtest.h> -#include "input.cuh" #include "cuda_helper.cuh" - +#include "input.cuh" TEST(InputLayerTest, Init) { std::vector<float> input
= {0.573f, 0.619f, 0.732f, 0.055f, 0.243f, 0.316f}; - Layers::Input inputLayer(6); - float* d_output = inputLayer.forward(input.data()); + CUDANet::Layers::Input inputLayer(6); + float* d_output = inputLayer.forward(input.data()); std::vector<float> output(6); - CUDA_CHECK(cudaMemcpy(output.data(), d_output, sizeof(float) * 6, cudaMemcpyDeviceToHost)); + CUDA_CHECK(cudaMemcpy( + output.data(), d_output, sizeof(float) * 6, cudaMemcpyDeviceToHost + )); EXPECT_EQ(input, output); - }
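Note: with this patch applied, all kernels and layers live under the top-level CUDANet namespace, so callers qualify them as CUDANet::Kernels::... and CUDANet::Layers::.... The snippet below is only an illustrative sketch of that usage, mirroring test/layers/test_input.cu and test/layers/test_dense.cu; it is not part of the patch, and the sizes, weights, and activation choice are made up for the example.

#include <vector>

#include "cuda_helper.cuh"
#include "dense.cuh"
#include "input.cuh"

int main() {
    // Host-side input; the Input layer copies it to device memory
    std::vector<float> input = {0.1f, 0.2f, 0.3f, 0.4f};

    CUDANet::Layers::Input inputLayer(4);
    float* d_input = inputLayer.forward(input.data());

    // 4 -> 2 fully connected layer, activation qualified with the new namespace
    CUDANet::Layers::Dense dense(4, 2, CUDANet::Layers::Activation::RELU);

    // Arbitrary weights/biases just so the forward pass is fully defined
    std::vector<float> weights(4 * 2, 0.5f);
    std::vector<float> biases(2, 0.0f);
    dense.setWeights(weights.data());
    dense.setBiases(biases.data());

    float* d_output = dense.forward(d_input);

    std::vector<float> output(2);
    CUDA_CHECK(cudaMemcpy(
        output.data(), d_output, sizeof(float) * 2, cudaMemcpyDeviceToHost
    ));

    return 0;
}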