From 7f203b8947aaf0dd0233457d50c2306a65d63d61 Mon Sep 17 00:00:00 2001
From: LordMathis
Date: Tue, 18 Nov 2025 21:12:47 +0100
Subject: [PATCH] WIP Migrate Dense layer

---
 include/backend.hpp                    | 44 ++++++++++-----
 include/backend/cuda.cuh               | 33 ++++++++---
 include/layer.hpp                      |  4 +-
 include/layers/activation.hpp          | 18 +++---
 include/layers/dense.hpp               | 45 ++++-----------
 src/backends/cuda/layer_ops.cu         | 20 +++++++
 src/backends/cuda/layers/activation.cu | 77 --------------------------
 src/backends/cuda/layers/dense.cu      | 69 -----------------------
 src/backends/cuda/tensor_ops.cu        |  4 ++
 src/layers/activation.cpp              |  4 +-
 src/layers/dense.cpp                   | 14 ++++-
 src/{model => }/model.cpp              |  0
 src/{model => }/module.cpp             |  0
 src/{backends => }/tensor.cpp          |  5 ++
 14 files changed, 116 insertions(+), 221 deletions(-)
 delete mode 100644 src/backends/cuda/layers/activation.cu
 delete mode 100644 src/backends/cuda/layers/dense.cu
 rename src/{model => }/model.cpp (100%)
 rename src/{model => }/module.cpp (100%)
 rename src/{backends => }/tensor.cpp (91%)

diff --git a/include/backend.hpp b/include/backend.hpp
index e8d397a..389f52b 100644
--- a/include/backend.hpp
+++ b/include/backend.hpp
@@ -4,27 +4,41 @@
 
 #include "tensor.hpp"
 
-namespace CUDANet
-{
-
-class Backend
-{
-public:
+namespace CUDANet {
+class Backend {
+   public:
 
     // Memory management
     virtual void* allocate(size_t bytes) = 0;
-    virtual void deallocate(void* ptr) = 0;
+    virtual void  deallocate(void* ptr) = 0;
 
     // Tensor ops
-    virtual void print(const CUDANet::Tensor &input) = 0;
-    virtual void zero(CUDANet::Tensor &input) = 0;
-    virtual void sum(const CUDANet::Tensor &input, CUDANet::Tensor &sum) = 0;
-    virtual void max(const CUDANet::Tensor &input, CUDANet::Tensor &max) = 0;
+    virtual void print(const CUDANet::Tensor& input) = 0;
+    virtual void zero(CUDANet::Tensor& input) = 0;
+
+    virtual void
+    copy_to_device(CUDANet::Tensor& tensor, void* data, size_t size) = 0;
+
+    virtual void sum(const CUDANet::Tensor& input, CUDANet::Tensor& sum) = 0;
+    virtual void max(const CUDANet::Tensor& input, CUDANet::Tensor& max) = 0;
 
     // Layer ops
-    virtual void relu(CUDANet::Tensor &tensor) = 0;
-    virtual void sigmoid(CUDANet::Tensor &tensor) = 0;
-    virtual void softmax(CUDANet::Tensor &tensor, CUDANet::Tensor &temp_max, CUDANet::Tensor &temp_sum) = 0;
+    virtual void relu(CUDANet::Tensor& tensor) = 0;
+    virtual void sigmoid(CUDANet::Tensor& tensor) = 0;
+    virtual void softmax(
+        CUDANet::Tensor& tensor,
+        CUDANet::Tensor& temp_max,
+        CUDANet::Tensor& temp_sum
+    ) = 0;
+
+    virtual CUDANet::Tensor& dense(
+        CUDANet::Tensor& weights,
+        CUDANet::Tensor& biases,
+        CUDANet::Tensor& input,
+        CUDANet::Tensor& output,
+        size_t input_size,
+        size_t output_size
+    ) = 0;
 };
 
-}  // namespace CUDANet::Backend
\ No newline at end of file
+}  // namespace CUDANet
\ No newline at end of file
diff --git a/include/backend/cuda.cuh b/include/backend/cuda.cuh
index 5045e28..1be8378 100644
--- a/include/backend/cuda.cuh
+++ b/include/backend/cuda.cuh
@@ -6,21 +6,36 @@ namespace CUDANet::Backend {
 
 class CUDA : public Backend {
-public:
+   public:
 
     // Memory management
     void* allocate(size_t bytes) override;
-    void deallocate(void* ptr) override;
+    void  deallocate(void* ptr) override;
 
     // Tensor ops
-    void print(const CUDANet::Tensor &input) override;
-    void zero(CUDANet::Tensor &input) override;
-    void sum(const CUDANet::Tensor &input, CUDANet::Tensor &sum) override;
-    void max(const CUDANet::Tensor &input, CUDANet::Tensor &max) override;
+    void print(const CUDANet::Tensor& input) override;
+    void zero(CUDANet::Tensor& input) override;
+    void
+    copy_to_device(CUDANet::Tensor& tensor, void* data, size_t size) override;
+    void sum(const CUDANet::Tensor& input, CUDANet::Tensor& sum) override;
+    void max(const CUDANet::Tensor& input, CUDANet::Tensor& max) override;
 
     // Layer ops
-    void relu(CUDANet::Tensor &tensor) override;
-    void sigmoid(CUDANet::Tensor &tensor) override;
-    void softmax(CUDANet::Tensor &tensor, CUDANet::Tensor &temp_max, CUDANet::Tensor &temp_sum) override;
+    void relu(CUDANet::Tensor& tensor) override;
+    void sigmoid(CUDANet::Tensor& tensor) override;
+    void softmax(
+        CUDANet::Tensor& tensor,
+        CUDANet::Tensor& temp_max,
+        CUDANet::Tensor& temp_sum
+    ) override;
+
+    CUDANet::Tensor& dense(
+        CUDANet::Tensor& weights,
+        CUDANet::Tensor& biases,
+        CUDANet::Tensor& input,
+        CUDANet::Tensor& output,
+        size_t input_size,
+        size_t output_size
+    ) override;
 };
 
 }  // namespace CUDANet::Backend
\ No newline at end of file
diff --git a/include/layer.hpp b/include/layer.hpp
index b50da8c..5b02346 100644
--- a/include/layer.hpp
+++ b/include/layer.hpp
@@ -30,11 +30,11 @@ class Layer {
 
     virtual size_t output_size() = 0;
 
-    virtual void set_weights(CUDANet::Tensor &input) = 0;
+    virtual void set_weights(void *input) = 0;
 
     virtual CUDANet::Tensor& get_weights() = 0;
 
-    virtual void set_biases(CUDANet::Tensor &input) = 0;
+    virtual void set_biases(void *input) = 0;
 
     virtual CUDANet::Tensor& get_biases() = 0;
 };
diff --git a/include/layers/activation.hpp b/include/layers/activation.hpp
index 8ad52b2..323ec22 100644
--- a/include/layers/activation.hpp
+++ b/include/layers/activation.hpp
@@ -29,23 +29,23 @@ class Activation : public Layer {
 
     ~Activation() = default;
 
-    CUDANet::Tensor& forward(CUDANet::Tensor &input);
+    CUDANet::Tensor& forward(CUDANet::Tensor &input) override;
 
-    CUDANet::Shape input_shape();
+    CUDANet::Shape input_shape() override;
 
-    CUDANet::Shape output_shape();
+    CUDANet::Shape output_shape() override;
 
-    size_t input_size();
+    size_t input_size() override;
 
-    size_t output_size();
+    size_t output_size() override;
 
-    void set_weights(CUDANet::Tensor &input);
+    void set_weights(void *input) override;
 
-    CUDANet::Tensor& get_weights();
+    CUDANet::Tensor& get_weights() override;
 
-    void set_biases(CUDANet::Tensor &input);
+    void set_biases(void *input) override;
 
-    CUDANet::Tensor& get_biases();
+    CUDANet::Tensor& get_biases() override;
 
   private:
 
diff --git a/include/layers/dense.hpp b/include/layers/dense.hpp
index d6bee21..6e5356c 100644
--- a/include/layers/dense.hpp
+++ b/include/layers/dense.hpp
@@ -18,23 +18,23 @@ class Dense : public Layer {
 
     ~Dense();
 
-    CUDANet::Tensor& forward(CUDANet::Tensor &input);
+    CUDANet::Tensor& forward(CUDANet::Tensor &input) override;
 
-    CUDANet::Shape input_shape();
+    CUDANet::Shape input_shape() override;
 
-    CUDANet::Shape output_shape();
+    CUDANet::Shape output_shape() override;
 
-    size_t input_size();
+    size_t input_size() override;
 
-    size_t output_size();
+    size_t output_size() override;
 
-    void set_weights(CUDANet::Tensor &input);
+    void set_weights(void *input) override;
 
-    CUDANet::Tensor& get_weights();
+    CUDANet::Tensor& get_weights() override;
 
-    void set_biases(CUDANet::Tensor &input);
+    void set_biases(void *input) override;
 
-    CUDANet::Tensor& get_biases();
+    CUDANet::Tensor& get_biases() override;
 
   private:
     CUDANet::Backend *backend;
@@ -45,32 +45,7 @@ class Dense : public Layer {
 
     CUDANet::Tensor weights;
     CUDANet::Tensor biases;
-
-    void init_weights();
-    void init_biases();
-
-// #ifdef USE_CUDA
-//     float* d_output;
-
-//     float* d_weights;
-//     float* d_biases;
-
-//     // Precompute kernel launch parameters
-//     int forwardGridSize;
-//     int biasGridSize;
-
-//     /**
-//      * @brief Copy the weights and biases to the device
-//      *
-//      */
-//     void toCuda();
-
-//     void initCUDA();
-//     void delCUDA();
-
-//     float* forwardCUDA(const float* d_input);
-// #endif
-
+    CUDANet::Tensor output;
 };
 
 }  // namespace CUDANet::Layers
diff --git a/src/backends/cuda/layer_ops.cu b/src/backends/cuda/layer_ops.cu
index 252d403..d08965e 100644
--- a/src/backends/cuda/layer_ops.cu
+++ b/src/backends/cuda/layer_ops.cu
@@ -45,4 +45,24 @@ void CUDA::softmax(Tensor &tensor, Tensor &temp_max, Tensor &temp_sum) {
     );
     CUDA_CHECK(cudaGetLastError());
     CUDA_CHECK(cudaDeviceSynchronize());
+}
+
+CUDANet::Tensor& CUDA::dense(CUDANet::Tensor &weights, CUDANet::Tensor &biases, CUDANet::Tensor &input, CUDANet::Tensor &output, size_t input_size, size_t output_size) {
+
+    auto forwardGridSize =
+        (std::max(input_size, output_size) + BLOCK_SIZE - 1) / BLOCK_SIZE;
+    auto biasGridSize = (output_size + BLOCK_SIZE - 1) / BLOCK_SIZE;
+
+    Kernels::mat_vec_mul<<<forwardGridSize, BLOCK_SIZE>>>(
+        weights.data(), input.data(), output.data(), input_size, output_size
+    );
+    CUDA_CHECK(cudaGetLastError());
+
+    Kernels::vec_vec_add<<<biasGridSize, BLOCK_SIZE>>>(
+        biases.data(), output.data(), output.data(), output_size
+    );
+    CUDA_CHECK(cudaGetLastError());
+    CUDA_CHECK(cudaDeviceSynchronize());
+
+    return output;
 }
\ No newline at end of file
diff --git a/src/backends/cuda/layers/activation.cu b/src/backends/cuda/layers/activation.cu
deleted file mode 100644
index 9b9dbca..0000000
--- a/src/backends/cuda/layers/activation.cu
+++ /dev/null
@@ -1,77 +0,0 @@
-#include
-
-#include "activation.hpp"
-#include "activation_functions.cuh"
-#include "cuda_helper.cuh"
-#include "matmul.cuh"
-#include "vector.cuh"
-
-using namespace CUDANet::Layers;
-
-void Activation::initCUDA() {
-    if (activationType == SOFTMAX) {
-        d_softmax_sum = nullptr;
-        CUDA_CHECK(cudaMalloc((void**)&d_softmax_sum, sizeof(float) * length));
-
-        d_max = nullptr;
-        CUDA_CHECK(cudaMalloc((void**)&d_max, sizeof(float) * length));
-    }
-
-    gridSize = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
-}
-
-void Activation::delCUDA() {
-    if (activationType == SOFTMAX) {
-        CUDA_CHECK(cudaFree(d_softmax_sum));
-        CUDA_CHECK(cudaFree(d_max));
-    }
-}
-
-void Activation::activateCUDA(float* d_input) {
-
-    // float sum = 0.0f;
-
-    switch (activationType) {
-        case SIGMOID:
-            Kernels::sigmoid<<<gridSize, BLOCK_SIZE>>>(
-                d_input, d_input, length
-            );
-            CUDA_CHECK(cudaGetLastError());
-            break;
-
-        case RELU:
-            Kernels::relu<<<gridSize, BLOCK_SIZE>>>(d_input, d_input, length);
-            CUDA_CHECK(cudaGetLastError());
-            break;
-        case SOFTMAX:
-
-            // Find max value
-            Utils::max(d_input, d_max, length);
-
-            // Subtract max value to improve numerical stability
-            Kernels::vec_scalar_sub<<<gridSize, BLOCK_SIZE>>>(
-                d_input, d_input, &d_max[0], length
-            );
-            CUDA_CHECK(cudaGetLastError());
-
-            // Compute exponentials
-            Kernels::vec_exp<<<gridSize, BLOCK_SIZE>>>(
-                d_input, d_input, length
-            );
-            CUDA_CHECK(cudaGetLastError());
-
-            // Find sum
-            Utils::sum(d_input, d_softmax_sum, length);
-
-            Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
-                d_input, d_input, &d_softmax_sum[0], length
-            );
-            CUDA_CHECK(cudaGetLastError());
-            break;
-
-        default:
-            break;
-    }
-
-    CUDA_CHECK(cudaDeviceSynchronize());
-}
diff --git a/src/backends/cuda/layers/dense.cu b/src/backends/cuda/layers/dense.cu
deleted file mode 100644
index f334376..0000000
--- a/src/backends/cuda/layers/dense.cu
+++ /dev/null
@@ -1,69 +0,0 @@
-#include
-
-#include
-#include
-#include
-#include
-
-#include "vector.cuh"
-#include "activation.hpp"
"cuda_helper.cuh" -#include "dense.hpp" -#include "matmul.cuh" - -using namespace CUDANet::Layers; - -void Dense::initCUDA() { - d_output = nullptr; - - CUDA_CHECK(cudaMalloc((void**)&d_output, sizeof(float) * outputSize)); - - d_weights = nullptr; - d_biases = nullptr; - - // Allocate GPU memory for weights and biases - CUDA_CHECK( - cudaMalloc((void**)&d_weights, sizeof(float) * inputSize * outputSize) - ); - CUDA_CHECK(cudaMalloc((void**)&d_biases, sizeof(float) * outputSize)); - toCuda(); - - // Calculate block and grid sizes - forwardGridSize = - (std::max(inputSize, outputSize) + BLOCK_SIZE - 1) / BLOCK_SIZE; - biasGridSize = (outputSize + BLOCK_SIZE - 1) / BLOCK_SIZE; -} - -void Dense::delCUDA() { - cudaFree(d_output); - cudaFree(d_weights); - cudaFree(d_biases); -} - -void Dense::toCuda() { - CUDA_CHECK(cudaMemcpy( - d_weights, weights.data(), sizeof(float) * inputSize * outputSize, - cudaMemcpyHostToDevice - )); - CUDA_CHECK(cudaMemcpy( - d_biases, biases.data(), sizeof(float) * outputSize, - cudaMemcpyHostToDevice - )); -} - -float* Dense::forwardCUDA(const float* d_input) { - Kernels::mat_vec_mul<<>>( - d_weights, d_input, d_output, inputSize, outputSize - ); - CUDA_CHECK(cudaGetLastError()); - - Kernels::vec_vec_add<<>>( - d_biases, d_output, d_output, outputSize - ); - CUDA_CHECK(cudaGetLastError()); - - activation->activate(d_output); - CUDA_CHECK(cudaDeviceSynchronize()); - - return d_output; -} diff --git a/src/backends/cuda/tensor_ops.cu b/src/backends/cuda/tensor_ops.cu index ef9e256..41ba2ba 100644 --- a/src/backends/cuda/tensor_ops.cu +++ b/src/backends/cuda/tensor_ops.cu @@ -26,6 +26,10 @@ void CUDA::zero(CUDANet::Tensor &input) { CUDA_CHECK(cudaMemset(input.data(), 0, sizeof(float) * input.numel())); } +void CUDA::copy_to_device(CUDANet::Tensor &tensor, void *data, size_t size) { + CUDA_CHECK(cudaMemcpy(tensor.data(), data, size, cudaMemcpyHostToDevice)); +} + void CUDA::sum(const CUDANet::Tensor &input, CUDANet::Tensor &sum) { auto length = input.numel(); const int gridSize = ( + BLOCK_SIZE - 1) / BLOCK_SIZE; diff --git a/src/layers/activation.cpp b/src/layers/activation.cpp index 67e468d..def0c67 100644 --- a/src/layers/activation.cpp +++ b/src/layers/activation.cpp @@ -57,10 +57,10 @@ size_t Activation::output_size() { return shape[0]; } -void Activation::set_weights(CUDANet::Tensor &input) {} +void Activation::set_weights(void *input) {} CUDANet::Tensor& Activation::get_weights() {} -void Activation::set_biases(CUDANet::Tensor &input) {} +void Activation::set_biases(void *input) {} CUDANet::Tensor& Activation::get_biases() {} \ No newline at end of file diff --git a/src/layers/dense.cpp b/src/layers/dense.cpp index 245281c..a164bcb 100644 --- a/src/layers/dense.cpp +++ b/src/layers/dense.cpp @@ -22,12 +22,16 @@ Dense::Dense(CUDANet::Backend *backend, CUDANet::Shape input_shape, CUDANet::Sha auto weights = CUDANet::Tensor{Shape(input_len * output_len), CUDANet::DType::FLOAT32, backend}; auto biases = CUDANet::Tensor(Shape(output_len), CUDANet::DType::FLOAT32, backend); + auto output = CUDANet::Tensor(Shape(output_len), CUDANet::DType::FLOAT32, backend); weights.zero(); biases.zero(); } -CUDANet::Tensor& Dense::forward(CUDANet::Tensor &input); +CUDANet::Tensor& Dense::forward(CUDANet::Tensor &input) { + backend->dense(weights, biases, input, output, in_shape[0], out_shape[0]); + return output; +} CUDANet::Shape Dense::input_shape() { return in_shape; @@ -45,13 +49,17 @@ size_t Dense::output_size() { return out_shape[0]; }; -void Dense::set_weights(CUDANet::Tensor 
+void Dense::set_weights(void *input) {
+    weights.set_data(static_cast<float*>(input));
+}
 
 CUDANet::Tensor& Dense::get_weights() {
     return weights;
 }
 
-void Dense::set_biases(CUDANet::Tensor &input);
+void Dense::set_biases(void *input) {
+    biases.set_data(static_cast<float*>(input));
+}
 
 CUDANet::Tensor& Dense::get_biases() {
     return biases;
diff --git a/src/model/model.cpp b/src/model.cpp
similarity index 100%
rename from src/model/model.cpp
rename to src/model.cpp
diff --git a/src/model/module.cpp b/src/module.cpp
similarity index 100%
rename from src/model/module.cpp
rename to src/module.cpp
diff --git a/src/backends/tensor.cpp b/src/tensor.cpp
similarity index 91%
rename from src/backends/tensor.cpp
rename to src/tensor.cpp
index f15a7e9..88325cc 100644
--- a/src/backends/tensor.cpp
+++ b/src/tensor.cpp
@@ -54,3 +54,8 @@ T* Tensor::data() {
 void Tensor::zero() {
     backend->zero(*this);
 }
+
+template <typename T>
+void Tensor::set_data(T *data) {
+    backend->copy_to_device(*this, data, total_size);
+}
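
Note on the new Backend::dense contract: the two kernel launches in layer_ops.cu compute output = weights * input + biases for a single input vector. The host-side reference below is an illustration only, not code from this patch, and it assumes a row-major weight layout of output_size rows by input_size columns, which is what Kernels::mat_vec_mul appears to expect.

    // Reference sketch of the dense forward pass on the CPU (illustration only).
    // Assumes weights are stored row-major as [output_size x input_size].
    #include <cstddef>
    #include <vector>

    void dense_reference(const std::vector<float>& weights,
                         const std::vector<float>& biases,
                         const std::vector<float>& input,
                         std::vector<float>&       output,
                         std::size_t input_size, std::size_t output_size) {
        for (std::size_t row = 0; row < output_size; ++row) {
            float acc = biases[row];
            for (std::size_t col = 0; col < input_size; ++col) {
                acc += weights[row * input_size + col] * input[col];
            }
            // Same result the mat_vec_mul + vec_vec_add launches produce on the GPU.
            output[row] = acc;
        }
    }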
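A possible call path for the migrated layer is sketched below. It assumes the Dense constructor signature shown in src/layers/dense.cpp (backend pointer, input shape, output shape), that CUDANet::Shape is constructible from a single extent as in that file, that the headers are reachable under the include paths used here, and that Tensor::set_data is instantiated for float; the buffer values are made up for illustration.

    // Hypothetical driver for the migrated Dense layer; not part of the patch.
    #include <cstddef>
    #include <vector>

    #include "backend/cuda.cuh"
    #include "layers/dense.hpp"
    #include "tensor.hpp"

    int main() {
        CUDANet::Backend::CUDA backend;  // concrete backend from include/backend/cuda.cuh

        const std::size_t in_len = 4, out_len = 2;
        CUDANet::Layers::Dense dense(
            &backend, CUDANet::Shape(in_len), CUDANet::Shape(out_len)
        );

        // Host parameters; set_weights/set_biases now take raw void* and copy
        // to the device through Backend::copy_to_device via Tensor::set_data.
        std::vector<float> w(in_len * out_len, 0.1f);
        std::vector<float> b(out_len, 0.0f);
        dense.set_weights(w.data());
        dense.set_biases(b.data());

        std::vector<float> x_host = {1.0f, 2.0f, 3.0f, 4.0f};
        CUDANet::Tensor x(CUDANet::Shape(in_len), CUDANet::DType::FLOAT32, &backend);
        x.set_data(x_host.data());  // host -> device copy

        CUDANet::Tensor& y = dense.forward(x);  // runs Backend::dense on the GPU
        backend.print(y);
        return 0;
    }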