From e0178e2d5c7aa6df32340fb0afb8292bbbd0dc75 Mon Sep 17 00:00:00 2001
From: LordMathis
Date: Mon, 11 Mar 2024 20:39:44 +0100
Subject: [PATCH] Cleanup and refactor

---
 include/layers/conv2d.cuh  | 13 ++---
 include/layers/dense.cuh   |  4 +-
 include/layers/ilayer.cuh  | 23 +++++++--
 src/layers/conv2d.cu       | 30 +++++++-----
 src/layers/dense.cu        | 41 ++++------------
 test/layers/test_conv2d.cu |  8 ++--
 test/layers/test_dense.cu  | 97 ++++++++++++++++++++------------------
 7 files changed, 108 insertions(+), 108 deletions(-)

diff --git a/include/layers/conv2d.cuh b/include/layers/conv2d.cuh
index 4c5a1cb..63e29ca 100644
--- a/include/layers/conv2d.cuh
+++ b/include/layers/conv2d.cuh
@@ -6,10 +6,11 @@
 
 #include "activations.cuh"
 #include "padding.cuh"
+#include "ilayer.cuh"
 
 namespace Layers {
 
-class Conv2d {
+class Conv2d : public ILayer {
   public:
     Conv2d(
         int inputSize,
@@ -26,8 +27,8 @@ class Conv2d {
     int outputSize;
 
     void forward(const float* d_input, float* d_output);
-    void setKernels(const std::vector<float>& kernels_input);
-
+    void setWeights(const float* weights_input);
+    void setBiases(const float* biases_input);
     void host_conv(const float* input, float* output);
 
   private:
@@ -42,18 +43,18 @@ class Conv2d {
     int numFilters;
 
     // Kernels
-    std::vector<float> kernels;
+    std::vector<float> weights;
     std::vector<float> biases;
 
     // Cuda
-    float* d_kernels;
+    float* d_weights;
     float* d_biases;
     float* d_padded;
 
     // Kernels
     Activation activation;
 
-    void initializeKernels();
+    void initializeWeights();
     void initializeBiases();
 
     void toCuda();
 };
diff --git a/include/layers/dense.cuh b/include/layers/dense.cuh
index 20afbf1..47d650f 100644
--- a/include/layers/dense.cuh
+++ b/include/layers/dense.cuh
@@ -19,8 +19,8 @@ class Dense : public ILayer {
     ~Dense();
 
     void forward(const float* d_input, float* d_output);
-    void setWeights(const std::vector<std::vector<float>>& weights);
-    void setBiases(const std::vector<float>& biases);
+    void setWeights(const float* weights);
+    void setBiases(const float* biases);
 
   private:
     int inputSize;
diff --git a/include/layers/ilayer.cuh b/include/layers/ilayer.cuh
index 2c0e826..7124a46 100644
--- a/include/layers/ilayer.cuh
+++ b/include/layers/ilayer.cuh
@@ -10,9 +10,26 @@ class ILayer {
   public:
     virtual ~ILayer() {}
 
-    virtual void forward(const float* input, float* output) = 0;
-    virtual void setWeights(const std::vector<std::vector<float>>& weights) = 0;
-    virtual void setBiases(const std::vector<float>& biases) = 0;
+    virtual void forward(const float* input, float* output) = 0;
+    virtual void setWeights(const float* weights) = 0;
+    virtual void setBiases(const float* biases) = 0;
+
+  private:
+    virtual void initializeWeights() = 0;
+    virtual void initializeBiases() = 0;
+
+    virtual void toCuda() = 0;
+
+    int inputSize;
+    int outputSize;
+
+    float* d_weights;
+    float* d_biases;
+
+    std::vector<float> weights;
+    std::vector<float> biases;
+
+    Activation activation;
 };
 
 } // namespace Layers
diff --git a/src/layers/conv2d.cu b/src/layers/conv2d.cu
index 8d82d2d..fdfd27e 100644
--- a/src/layers/conv2d.cu
+++ b/src/layers/conv2d.cu
@@ -23,7 +23,6 @@ Layers::Conv2d::Conv2d(
       stride(stride),
       numFilters(numFilters),
       activation(activation) {
-
     // Allocate memory for kernels
     switch (padding) {
@@ -41,12 +40,12 @@ Layers::Conv2d::Conv2d(
             break;
     }
 
-    kernels.resize(kernelSize * kernelSize * inputChannels * numFilters);
-    initializeKernels();
+    weights.resize(kernelSize * kernelSize * inputChannels * numFilters);
+    initializeWeights();
 
-    d_kernels = nullptr;
+    d_weights = nullptr;
     CUDA_CHECK(cudaMalloc(
-        (void**)&d_kernels,
+        (void**)&d_weights,
         sizeof(float) * kernelSize * kernelSize * inputChannels * numFilters
     ));
@@ -68,27 +67,32 @@ Layers::Conv2d::Conv2d(
 }
 
 Layers::Conv2d::~Conv2d() {
-    cudaFree(d_kernels);
+    cudaFree(d_weights);
     cudaFree(d_biases);
     cudaFree(d_padded);
 }
 
-void Layers::Conv2d::initializeKernels() {
-    std::fill(kernels.begin(), kernels.end(), 0.0f);
+void Layers::Conv2d::initializeWeights() {
+    std::fill(weights.begin(), weights.end(), 0.0f);
 }
 
 void Layers::Conv2d::initializeBiases() {
     std::fill(biases.begin(), biases.end(), 0.0f);
 }
 
-void Layers::Conv2d::setKernels(const std::vector<float>& kernels_input) {
-    std::copy(kernels_input.begin(), kernels_input.end(), kernels.begin());
+void Layers::Conv2d::setWeights(const float* weights_input) {
+    std::copy(weights_input, weights_input + weights.size(), weights.begin());
+    toCuda();
+}
+
+void Layers::Conv2d::setBiases(const float* biases_input) {
+    std::copy(biases_input, biases_input + biases.size(), biases.begin());
     toCuda();
 }
 
 void Layers::Conv2d::toCuda() {
     CUDA_CHECK(cudaMemcpy(
-        d_kernels, kernels.data(),
+        d_weights, weights.data(),
         sizeof(float) * kernelSize * kernelSize * inputChannels * numFilters,
         cudaMemcpyHostToDevice
     ));
@@ -112,7 +116,7 @@ void Layers::Conv2d::forward(const float* d_input, float* d_output) {
     // Convolve
     THREADS_PER_BLOCK = outputSize * outputSize * numFilters;
     convolution_kernel<<<1, THREADS_PER_BLOCK>>>(
-        d_padded, d_kernels, d_output, inputSize + (2 * paddingSize),
+        d_padded, d_weights, d_output, inputSize + (2 * paddingSize),
         inputChannels, kernelSize, stride, numFilters, outputSize
     );
@@ -155,7 +159,7 @@ void Layers::Conv2d::host_conv(const float* input, float* output) {
                         (i * stride + k) * inputSize +
                         (j * stride + l);
 
-                    sum += kernels[kernelIndex] * input[inputIndex];
+                    sum += weights[kernelIndex] * input[inputIndex];
                 }
             }
         }
diff --git a/src/layers/dense.cu b/src/layers/dense.cu
index 138d456..c888692 100644
--- a/src/layers/dense.cu
+++ b/src/layers/dense.cu
@@ -10,14 +10,8 @@
 #include "dense.cuh"
 #include "matrix_math.cuh"
 
-Layers::Dense::Dense(
-    int inputSize,
-    int outputSize,
-    Activation activation
-)
-    : inputSize(inputSize),
-      outputSize(outputSize),
-      activation(activation) {
+Layers::Dense::Dense(int inputSize, int outputSize, Activation activation)
+    : inputSize(inputSize), outputSize(outputSize), activation(activation) {
     // Allocate memory for weights and biases
     weights.resize(outputSize * inputSize);
     biases.resize(outputSize);
@@ -52,7 +46,6 @@ void Layers::Dense::initializeBiases() {
 }
 
 void Layers::Dense::forward(const float* d_input, float* d_output) {
-
     mat_vec_mul_kernel<<<1, outputSize>>>(
         d_weights, d_input, d_output, inputSize, outputSize
     );
@@ -63,15 +56,11 @@ void Layers::Dense::forward(const float* d_input, float* d_output) {
 
     switch (activation) {
         case SIGMOID:
-            sigmoid_kernel<<<1, outputSize>>>(
-                d_output, d_output, outputSize
-            );
+            sigmoid_kernel<<<1, outputSize>>>(d_output, d_output, outputSize);
             break;
 
         case RELU:
-            relu_kernel<<<1, outputSize>>>(
-                d_output, d_output, outputSize
-            );
+            relu_kernel<<<1, outputSize>>>(d_output, d_output, outputSize);
             break;
 
         default:
@@ -92,26 +81,12 @@ void Layers::Dense::toCuda() {
     ));
 }
 
-void Layers::Dense::setWeights(
-    const std::vector<std::vector<float>>& weights_input
-) {
-    int numWeights = inputSize * outputSize;
-
-    if (weights.size() != numWeights) {
-        std::cerr << "Invalid number of weights" << std::endl;
-        exit(EXIT_FAILURE);
-    }
-
-    for (int i = 0; i < outputSize; ++i) {
-        for (int j = 0; j < inputSize; ++j) {
-            weights[i * inputSize + j] = weights_input[i][j];
-        }
-    }
-
+void Layers::Dense::setWeights(const float* weights_input) {
+    std::copy(weights_input, weights_input + weights.size(), weights.begin());
     toCuda();
 }
 
-void Layers::Dense::setBiases(const std::vector<float>& biases_input) {
-    std::copy(biases_input.begin(), biases_input.end(), biases.begin());
+void Layers::Dense::setBiases(const float* biases_input) {
+    std::copy(biases_input, biases_input + biases.size(), biases.begin());
     toCuda();
 }
\ No newline at end of file
diff --git a/test/layers/test_conv2d.cu b/test/layers/test_conv2d.cu
index 527069d..33ab730 100644
--- a/test/layers/test_conv2d.cu
+++ b/test/layers/test_conv2d.cu
@@ -16,7 +16,7 @@ class Conv2dTest : public ::testing::Test {
         int numFilters,
         Activation activation,
         std::vector<float>& input,
-        std::vector<float>& kernels,
+        float* kernels,
         float*& d_input,
         float*& d_output
     ) {
@@ -26,7 +26,7 @@ class Conv2dTest : public ::testing::Test {
             activation
         );
 
-        conv2d.setKernels(kernels);
+        conv2d.setWeights(kernels);
 
         // Allocate device memory
         cudaStatus = cudaMalloc(
@@ -84,7 +84,7 @@ TEST_F(Conv2dTest, SimpleTest) {
 
     Layers::Conv2d conv2d = commonTestSetup(
         inputSize, inputChannels, kernelSize, stride, padding, numFilters,
-        activation, input, kernels, d_input, d_output
+        activation, input, kernels.data(), d_input, d_output
     );
 
     int outputSize = (inputSize - kernelSize) / stride + 1;
@@ -173,7 +173,7 @@ TEST_F(Conv2dTest, ComplexTest) {
 
     Layers::Conv2d conv2d = commonTestSetup(
         inputSize, inputChannels, kernelSize, stride, padding, numFilters,
-        activation, input, kernels, d_input, d_output
+        activation, input, kernels.data(), d_input, d_output
    );
 
     EXPECT_EQ(inputSize, conv2d.outputSize);
diff --git a/test/layers/test_dense.cu b/test/layers/test_dense.cu
index b8c8147..e1b78d3 100644
--- a/test/layers/test_dense.cu
+++ b/test/layers/test_dense.cu
@@ -6,23 +6,20 @@
 #include "activations.cuh"
 #include "dense.cuh"
 
-
-class DenseLayerTest : public::testing::Test {
+class DenseLayerTest : public ::testing::Test {
   protected:
     Layers::Dense commonTestSetup(
-        int inputSize,
-        int outputSize,
-        std::vector<float>& input,
-        std::vector<std::vector<float>>& weights,
-        std::vector<float>& biases,
-        float*& d_input,
-        float*& d_output,
-        Activation activation
+        int inputSize,
+        int outputSize,
+        std::vector<float>& input,
+        float* weights,
+        float* biases,
+        float*& d_input,
+        float*& d_output,
+        Activation activation
     ) {
         // Create Dense layer
-        Layers::Dense denseLayer(
-            inputSize, outputSize, activation
-        );
+        Layers::Dense denseLayer(inputSize, outputSize, activation);
 
         // Set weights and biases
         denseLayer.setWeights(weights);
@@ -37,11 +34,11 @@ class DenseLayerTest : public ::testing::Test {
 
         // Copy input to device
         cudaStatus = cudaMemcpy(
-            d_input, input.data(), sizeof(float) * input.size(), cudaMemcpyHostToDevice
+            d_input, input.data(), sizeof(float) * input.size(),
+            cudaMemcpyHostToDevice
         );
         EXPECT_EQ(cudaStatus, cudaSuccess);
 
-
         return denseLayer;
     }
 
@@ -51,7 +48,7 @@ class DenseLayerTest : public ::testing::Test {
         cudaFree(d_output);
     }
 
-   cudaError_t cudaStatus;
+    cudaError_t cudaStatus;
 };
 
 TEST_F(DenseLayerTest, Init) {
@@ -60,9 +57,7 @@ TEST_F(DenseLayerTest, Init) {
             int inputSize = i;
             int outputSize = j;
 
-            Layers::Dense denseLayer(
-                inputSize, outputSize, SIGMOID
-            );
+            Layers::Dense denseLayer(inputSize, outputSize, SIGMOID);
         }
     }
 }
@@ -71,17 +66,19 @@ TEST_F(DenseLayerTest, setWeights) {
     int inputSize = 4;
     int outputSize = 5;
 
-    std::vector<std::vector<float>> weights = {
-        {0.5f, 1.0f, 0.2f, 0.8f},
-        {1.2f, 0.3f, 1.5f, 0.4f},
-        {0.7f, 1.8f, 0.9f, 0.1f},
-        {0.4f, 2.0f, 0.6f, 1.1f},
-        {1.3f, 0.5f, 0.0f, 1.7f}
+    // clang-format off
+    std::vector<float> weights = {
+        0.5f, 1.0f, 0.2f, 0.8f,
+        1.2f, 0.3f, 1.5f, 0.4f,
+        0.7f, 1.8f, 0.9f, 0.1f,
+        0.4f, 2.0f, 0.6f, 1.1f,
+        1.3f, 0.5f, 0.0f, 1.7f
     };
+    // clang-format on
 
     Layers::Dense denseLayer(inputSize, outputSize, SIGMOID);
 
-    denseLayer.setWeights(weights);
+    denseLayer.setWeights(weights.data());
 }
 
 TEST_F(DenseLayerTest, ForwardUnitWeightMatrixLinear) {
@@ -90,13 +87,11 @@ TEST_F(DenseLayerTest, ForwardUnitWeightMatrixLinear) {
     int inputSize = 3;
     int outputSize = 3;
 
     std::vector<float> input = {1.0f, 2.0f, 3.0f};
 
-    std::vector<std::vector<float>> weights(
-        inputSize, std::vector<float>(outputSize, 0.0f)
-    );
+    std::vector<float> weights(outputSize * inputSize, 0.0f);
     for (int i = 0; i < inputSize; ++i) {
         for (int j = 0; j < outputSize; ++j) {
             if (i == j) {
-                weights[i][j] = 1.0f;
+                weights[i * outputSize + j] = 1.0f;
             }
         }
     }
@@ -106,13 +101,15 @@ TEST_F(DenseLayerTest, ForwardUnitWeightMatrixLinear) {
     float* d_input;
     float* d_output;
 
     Layers::Dense denseLayer = commonTestSetup(
-        inputSize, outputSize, input, weights, biases, d_input, d_output, LINEAR
+        inputSize, outputSize, input, weights.data(), biases.data(), d_input,
+        d_output, LINEAR
     );
     denseLayer.forward(d_input, d_output);
 
     std::vector<float> output(outputSize);
     cudaStatus = cudaMemcpy(
-        output.data(), d_output, sizeof(float) * outputSize, cudaMemcpyDeviceToHost
+        output.data(), d_output, sizeof(float) * outputSize,
+        cudaMemcpyDeviceToHost
     );
     EXPECT_EQ(cudaStatus, cudaSuccess);
@@ -130,26 +127,30 @@ TEST_F(DenseLayerTest, ForwardRandomWeightMatrixRelu) {
 
     std::vector<float> input = {1.0f, 2.0f, 3.0f, 4.0f, -5.0f};
 
-    std::vector<std::vector<float>> weights = {
-        {0.5f, 1.2f, 0.7f, 0.4f, 1.3f},
-        {1.0f, 0.3f, 1.8f, 2.0f, 0.5f},
-        {0.2f, 1.5f, 0.9f, 0.6f, 0.0f},
-        {0.8f, 0.4f, 0.1f, 1.1f, 1.7f}
+    // clang-format off
+    std::vector<float> weights = {
+        0.5f, 1.2f, 0.7f, 0.4f, 1.3f,
+        1.0f, 0.3f, 1.8f, 2.0f, 0.5f,
+        0.2f, 1.5f, 0.9f, 0.6f, 0.0f,
+        0.8f, 0.4f, 0.1f, 1.1f, 1.7f
     };
 
     std::vector<float> biases = {0.2f, 0.5f, 0.7f, -1.1f};
+    // clang-format on
 
     float* d_input;
     float* d_output;
 
     Layers::Dense denseLayer = commonTestSetup(
-        inputSize, outputSize, input, weights, biases, d_input, d_output, RELU
+        inputSize, outputSize, input, weights.data(), biases.data(), d_input,
+        d_output, RELU
    );
     denseLayer.forward(d_input, d_output);
 
     std::vector<float> output(outputSize);
     cudaStatus = cudaMemcpy(
-        output.data(), d_output, sizeof(float) * outputSize, cudaMemcpyDeviceToHost
+        output.data(), d_output, sizeof(float) * outputSize,
+        cudaMemcpyDeviceToHost
     );
     EXPECT_EQ(cudaStatus, cudaSuccess);
@@ -170,21 +171,22 @@ TEST_F(DenseLayerTest, ForwardRandomWeightMatrixSigmoid) {
     int inputSize = 5;
     int outputSize = 4;
 
+    // clang-format off
     std::vector<float> input = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f};
-
-    std::vector<std::vector<float>> weights = {
-        {0.8f, 0.7f, 0.7f, 0.3f, 0.8f},
-        {0.1f, 0.4f, 0.8f, 0.0f, 0.2f},
-        {0.2f, 0.5f, 0.7f, 0.3f, 0.0f},
-        {0.1f, 0.7f, 0.6f, 1.0f, 0.4f}
+    std::vector<float> weights = {
+        0.8f, 0.7f, 0.7f, 0.3f, 0.8f,
+        0.1f, 0.4f, 0.8f, 0.0f, 0.2f,
+        0.2f, 0.5f, 0.7f, 0.3f, 0.0f,
+        0.1f, 0.7f, 0.6f, 1.0f, 0.4f
     };
 
     std::vector<float> biases = {0.1f, 0.2f, 0.3f, 0.4f};
+    // clang-format on
 
     float* d_input;
     float* d_output;
 
     Layers::Dense denseLayer = commonTestSetup(
-        inputSize, outputSize, input, weights, biases, d_input, d_output,
-        SIGMOID
+        inputSize, outputSize, input, weights.data(), biases.data(), d_input,
+        d_output, SIGMOID
     );
 
     denseLayer.forward(d_input, d_output);
 
     std::vector<float> output(outputSize);
@@ -192,7 +194,8 @@ TEST_F(DenseLayerTest, ForwardRandomWeightMatrixSigmoid) {
     cudaStatus = cudaMemcpy(
-        output.data(), d_output, sizeof(float) * outputSize, cudaMemcpyDeviceToHost
+        output.data(), d_output, sizeof(float) * outputSize,
+        cudaMemcpyDeviceToHost
     );
     EXPECT_EQ(cudaStatus, cudaSuccess);
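
Usage sketch (annotation, not part of the diff): after this refactor both layers expose the ILayer raw-pointer setters, and a caller passes flattened row-major host buffers whose sizes must match what the layer allocated internally (outputSize * inputSize weights and outputSize biases for Dense). Below is a minimal hypothetical driver, mirroring the ForwardUnitWeightMatrixLinear test above; the file name and the identity/LINEAR values are illustrative, not part of the patch.

// example_dense.cu -- hypothetical driver for the refactored Dense API
#include <cuda_runtime.h>
#include <vector>

#include "activations.cuh"
#include "dense.cuh"

int main() {
    const int inputSize  = 3;
    const int outputSize = 3;

    // Flattened row-major [outputSize x inputSize] identity matrix; the new
    // setters copy weights.size()/biases.size() elements, so the caller must
    // supply exactly outputSize * inputSize and outputSize floats.
    std::vector<float> weights = {
        1.0f, 0.0f, 0.0f,
        0.0f, 1.0f, 0.0f,
        0.0f, 0.0f, 1.0f
    };
    std::vector<float> biases(outputSize, 0.0f);
    std::vector<float> input = {1.0f, 2.0f, 3.0f};

    Layers::Dense dense(inputSize, outputSize, LINEAR);
    dense.setWeights(weights.data());  // host copy, then cudaMemcpy via toCuda()
    dense.setBiases(biases.data());

    float* d_input  = nullptr;
    float* d_output = nullptr;
    cudaMalloc((void**)&d_input, sizeof(float) * inputSize);
    cudaMalloc((void**)&d_output, sizeof(float) * outputSize);
    cudaMemcpy(
        d_input, input.data(), sizeof(float) * inputSize, cudaMemcpyHostToDevice
    );

    dense.forward(d_input, d_output);

    std::vector<float> output(outputSize);
    cudaMemcpy(
        output.data(), d_output, sizeof(float) * outputSize,
        cudaMemcpyDeviceToHost
    );
    // With identity weights, zero biases, and LINEAR activation, output
    // now equals input.

    cudaFree(d_input);
    cudaFree(d_output);
    return 0;
}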