diff --git a/include/backend.hpp b/include/backend.hpp index 389f52b..1c69e51 100644 --- a/include/backend.hpp +++ b/include/backend.hpp @@ -32,12 +32,12 @@ class Backend { ) = 0; virtual CUDANet::Tensor& dense( - CUDANet::Tensor& weights, - CUDANet::Tensor& biases, - CUDANet::Tensor& input, + const CUDANet::Tensor& weights, + const CUDANet::Tensor& biases, + const CUDANet::Tensor& input, CUDANet::Tensor& output, - size_t input_size, - size_t output_size + const size_t input_size, + const size_t output_size ) = 0; }; diff --git a/include/backend/cuda.cuh b/include/backend/cuda.cuh index 1be8378..4489e3d 100644 --- a/include/backend/cuda.cuh +++ b/include/backend/cuda.cuh @@ -29,12 +29,12 @@ class CUDA : public Backend { ) override; CUDANet::Tensor& dense( - CUDANet::Tensor& weights, - CUDANet::Tensor& biases, - CUDANet::Tensor& input, + const CUDANet::Tensor& weights, + const CUDANet::Tensor& biases, + const CUDANet::Tensor& input, CUDANet::Tensor& output, - size_t input_size, - size_t output_size + const size_t input_size, + const size_t output_size ) override; }; diff --git a/include/layer.hpp b/include/layer.hpp index 5b02346..77d50de 100644 --- a/include/layer.hpp +++ b/include/layer.hpp @@ -20,7 +20,7 @@ class Layer { virtual ~Layer(){}; - virtual CUDANet::Tensor& forward(CUDANet::Tensor &input) = 0; + virtual CUDANet::Tensor& forward(const CUDANet::Tensor &input) = 0; virtual CUDANet::Shape input_shape() = 0; diff --git a/include/layers/dense.hpp b/include/layers/dense.hpp index 6e5356c..a74a4ab 100644 --- a/include/layers/dense.hpp +++ b/include/layers/dense.hpp @@ -18,7 +18,7 @@ class Dense : public Layer { ~Dense(); - CUDANet::Tensor& forward(CUDANet::Tensor &input) override; + CUDANet::Tensor& forward(const CUDANet::Tensor &input) override; CUDANet::Shape input_shape() override; diff --git a/include/tensor.hpp b/include/tensor.hpp index 5e074b9..b122fb0 100644 --- a/include/tensor.hpp +++ b/include/tensor.hpp @@ -22,6 +22,12 @@ public: Tensor() = default; Tensor(Shape shape, DType dtype, CUDANet::Backend* backend); + + Tensor(Tensor&& other) noexcept; + Tensor& operator=(Tensor&& other) noexcept; + Tensor(const Tensor&) = delete; + Tensor& operator=(const Tensor&) = delete; + ~Tensor(); size_t size() const; diff --git a/src/backends/cuda/layer_ops.cu b/src/backends/cuda/layer_ops.cu index d08965e..9d4fc0e 100644 --- a/src/backends/cuda/layer_ops.cu +++ b/src/backends/cuda/layer_ops.cu @@ -1,25 +1,29 @@ #include "backend/cuda.cuh" -#include "utils/cuda_helper.cuh" #include "kernels/activation_functions.cuh" #include "kernels/matmul.cuh" +#include "utils/cuda_helper.cuh" using namespace CUDANet::Backend; -void CUDA::relu(Tensor &tensor) { +void CUDA::relu(Tensor& tensor) { int gridSize = (tensor.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE; - Kernels::relu<<>>(tensor.data(), tensor.data(), tensor.numel()); + Kernels::relu<<>>( + tensor.data(), tensor.data(), tensor.numel() + ); CUDA_CHECK(cudaGetLastError()); CUDA_CHECK(cudaDeviceSynchronize()); } -void CUDA::sigmoid(Tensor &tensor) { +void CUDA::sigmoid(Tensor& tensor) { int gridSize = (tensor.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE; - Kernels::sigmoid<<>>(tensor.data(), tensor.data(), tensor.numel()); + Kernels::sigmoid<<>>( + tensor.data(), tensor.data(), tensor.numel() + ); CUDA_CHECK(cudaGetLastError()); CUDA_CHECK(cudaDeviceSynchronize()); } -void CUDA::softmax(Tensor &tensor, Tensor &temp_max, Tensor &temp_sum) { +void CUDA::softmax(Tensor& tensor, Tensor& temp_max, Tensor& temp_sum) { int gridSize = (tensor.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE; // Find max value @@ -27,7 +31,8 @@ void CUDA::softmax(Tensor &tensor, Tensor &temp_max, Tensor &temp_sum) { // Subtract max value to improve numerical stability Kernels::vec_scalar_sub<<>>( - tensor.data(), tensor.data(), temp_max.data(), tensor.numel() + tensor.data(), tensor.data(), temp_max.data(), + tensor.numel() ); CUDA_CHECK(cudaGetLastError()); @@ -36,30 +41,39 @@ void CUDA::softmax(Tensor &tensor, Tensor &temp_max, Tensor &temp_sum) { tensor.data(), tensor.data(), tensor.numel() ); CUDA_CHECK(cudaGetLastError()); - + // Find sum sum(tensor, temp_sum); Kernels::vec_scalar_div<<>>( - tensor.data(), tensor.data(), temp_sum.data(), tensor.numel() + tensor.data(), tensor.data(), temp_sum.data(), + tensor.numel() ); CUDA_CHECK(cudaGetLastError()); CUDA_CHECK(cudaDeviceSynchronize()); } -CUDANet::Tensor& CUDA::dense(CUDANet::Tensor &weights, CUDANet::Tensor &biases, CUDANet::Tensor &input, CUDANet::Tensor &output, size_t input_size, size_t output_size) { - +CUDANet::Tensor& CUDA::dense( + const CUDANet::Tensor& weights, + const CUDANet::Tensor& biases, + const CUDANet::Tensor& input, + CUDANet::Tensor& output, + const size_t input_size, + const size_t output_size +) { auto forwardGridSize = (std::max(input_size, output_size) + BLOCK_SIZE - 1) / BLOCK_SIZE; auto biasGridSize = (output_size + BLOCK_SIZE - 1) / BLOCK_SIZE; Kernels::mat_vec_mul<<>>( - weights.data(), input.data(), output.data(), input_size, output_size + weights.data(), input.data(), output.data(), + input_size, output_size ); CUDA_CHECK(cudaGetLastError()); Kernels::vec_vec_add<<>>( - biases.data(), output.data(), output.data(), output_size + biases.data(), output.data(), output.data(), + output_size ); CUDA_CHECK(cudaGetLastError()); CUDA_CHECK(cudaDeviceSynchronize()); diff --git a/src/layers/dense.cpp b/src/layers/dense.cpp index a164bcb..bdaa5ee 100644 --- a/src/layers/dense.cpp +++ b/src/layers/dense.cpp @@ -1,34 +1,41 @@ +#include "dense.hpp" + #include #include -#include "dense.hpp" - using namespace CUDANet::Layers; -Dense::Dense(CUDANet::Backend *backend, CUDANet::Shape input_shape, CUDANet::Shape output_shape) - : backend(backend), in_shape(input_shape), out_shape(output_shape) { +Dense::Dense(CUDANet::Backend* backend, CUDANet::Shape in, CUDANet::Shape out) + : backend(backend), + in_shape(in), + out_shape(out), + weights( + CUDANet::Tensor{{in[0] * out[0]}, CUDANet::DType::FLOAT32, backend} + ), + biases(CUDANet::Tensor({out[0]}, CUDANet::DType::FLOAT32, backend)), + output(CUDANet::Tensor({out[0]}, CUDANet::DType::FLOAT32, backend)) { // Allocate memory for weights and biases - if (input_shape.size() != 1) { - throw std::runtime_error(std::format("Invalid shape. Expected [1], got {}", input_shape)); - } - - if (output_shape.size() != 1) { - throw std::runtime_error(std::format("Invalid shape. Expected [1], got {}", output_shape)); + if (in.size() != 1) { + throw std::runtime_error( + std::format("Invalid shape. Expected [1], got {}", in) + ); } - auto input_len = input_shape[0]; - auto output_len = output_shape[0]; + if (out.size() != 1) { + throw std::runtime_error( + std::format("Invalid shape. Expected [1], got {}", out) + ); + } - auto weights = CUDANet::Tensor{Shape(input_len * output_len), CUDANet::DType::FLOAT32, backend}; - auto biases = CUDANet::Tensor(Shape(output_len), CUDANet::DType::FLOAT32, backend); - auto output = CUDANet::Tensor(Shape(output_len), CUDANet::DType::FLOAT32, backend); + auto input_len = in[0]; + auto output_len = out[0]; weights.zero(); biases.zero(); } -CUDANet::Tensor& Dense::forward(CUDANet::Tensor &input) { +CUDANet::Tensor& Dense::forward(const CUDANet::Tensor& input) { backend->dense(weights, biases, input, output, in_shape[0], out_shape[0]); return output; } @@ -49,7 +56,7 @@ size_t Dense::output_size() { return out_shape[0]; }; -void Dense::set_weights(void *input) { +void Dense::set_weights(void* input) { weights.set_data(static_cast(input)); } @@ -57,7 +64,7 @@ CUDANet::Tensor& Dense::get_weights() { return weights; } -void Dense::set_biases(void *input) { +void Dense::set_biases(void* input) { biases.set_data(static_cast(input)); } diff --git a/src/tensor.cpp b/src/tensor.cpp index 88325cc..bed6f8c 100644 --- a/src/tensor.cpp +++ b/src/tensor.cpp @@ -6,6 +6,11 @@ using namespace CUDANet; Tensor::Tensor(Shape shape, DType dtype, Backend* backend) : shape(shape), dtype(dtype), backend(backend), d_ptr(nullptr) { + + if (shape.empty()) { + throw std::runtime_error("Tensor shape cannot be empty"); + } + // Count total elements size_t count = 1; for (const auto& dim : shape) { @@ -28,6 +33,40 @@ Tensor::Tensor(Shape shape, DType dtype, Backend* backend) d_ptr = backend->allocate(total_size); } +Tensor::Tensor(Tensor&& other) noexcept + : shape(std::move(other.shape)), + dtype(other.dtype), + total_elms(other.total_elms), + total_size(other.total_size), + backend(other.backend), + d_ptr(other.d_ptr) +{ + other.d_ptr = nullptr; + other.backend = nullptr; +} + +Tensor& Tensor::operator=(Tensor&& other) noexcept { + if (this != &other) { + // Clean up our current resources + if (d_ptr != nullptr && backend != nullptr) { + backend->deallocate(d_ptr); + } + + // Steal other's resources + shape = std::move(other.shape); + dtype = other.dtype; + total_elms = other.total_elms; + total_size = other.total_size; + backend = other.backend; + d_ptr = other.d_ptr; + + // Leave other in valid but empty state + other.d_ptr = nullptr; + other.backend = nullptr; + } + return *this; +} + Tensor::~Tensor() { backend->deallocate(d_ptr); d_ptr = nullptr; @@ -57,5 +96,5 @@ void Tensor::zero() { template void Tensor::set_data(T *data) { - backend->copy_to_device(*this, data, total_size) + backend->copy_to_device(*this, data, total_size); } \ No newline at end of file