diff --git a/include/backend.hpp b/include/backend.hpp
index 6a41d61..30c96a4 100644
--- a/include/backend.hpp
+++ b/include/backend.hpp
@@ -1,8 +1,10 @@
 #pragma once
 
 #include <cstddef>
+#include <optional>
 
 #include "shape.hpp"
+#include "tensor.hpp"
 
 namespace CUDANet {
 
@@ -22,7 +24,14 @@ class BackendFactory {
 };
 
 class Backend {
+   protected:
+    std::optional<DType> default_dtype;
    public:
+
+    virtual bool supports_dtype(DType dtype) const = 0;
+    virtual void set_default_dtype(DType dtype) = 0;
+    virtual DType get_default_dtype() const = 0;
+
     // Memory management
     virtual void* allocate(size_t bytes) = 0;
     virtual void deallocate(void* ptr) = 0;
diff --git a/include/backend/cuda/cuda.cuh b/include/backend/cuda/cuda.cuh
index 10d2979..db548bc 100644
--- a/include/backend/cuda/cuda.cuh
+++ b/include/backend/cuda/cuda.cuh
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <cuda_runtime.h>
+#include <set>
 
 #include "backend.hpp"
 #include "tensor.hpp"
@@ -29,9 +30,14 @@ namespace CUDANet::Backends {
 class CUDA : public Backend {
    private:
     int device_id;
+    std::set<DType> supported_dtypes;
 
    public:
     CUDA(const BackendConfig& config);
+    bool supports_dtype(DType dtype) const override;
+    void set_default_dtype(DType dtype) override;
+    DType get_default_dtype() const override;
+
     static bool is_cuda_available();
 
     void initialize();
diff --git a/include/shape.hpp b/include/shape.hpp
index f54a1cf..b0c97b7 100644
--- a/include/shape.hpp
+++ b/include/shape.hpp
@@ -66,6 +66,12 @@ struct Shape {
     __host__ bool operator!=(const Shape& other) const {
         return !(*this == other);
     }
+
+    __host__ __device__ bool empty() const {
+        return ndim == 0;
+    }
+
+
 };
 
 std::string format_shape(const Shape& shape) {
diff --git a/include/tensor.hpp b/include/tensor.hpp
index ab2fcf7..fb22bec 100644
--- a/include/tensor.hpp
+++ b/include/tensor.hpp
@@ -16,11 +16,14 @@ enum class DType
     // INT32,  // Not implemented yet
 };
 
+size_t dtype_size(DType dtype);
+
 class Tensor {
 
 public:
     Tensor() = default;
 
+    Tensor(Shape shape, CUDANet::Backend* backend);
     Tensor(Shape shape, DType dtype, CUDANet::Backend* backend);
 
     Tensor(Tensor&& other) noexcept;
@@ -30,6 +33,8 @@ public:
 
     ~Tensor();
 
+    DType get_dtype();
+
     size_t size() const;
 
     size_t numel() const;
diff --git a/src/backends/cuda/cuda.cu b/src/backends/cuda/cuda.cu
index 8d51d24..c2094f2 100644
--- a/src/backends/cuda/cuda.cu
+++ b/src/backends/cuda/cuda.cu
@@ -5,12 +5,15 @@
 #include <cstdio>
 
 #include "backend/cuda/cuda.cuh"
+#include "tensor.hpp"
 
 using namespace CUDANet::Backends;
 
 CUDA::CUDA(const BackendConfig& config) {
     device_id = config.device_id < 0 ? 0 : config.device_id;
+    supported_dtypes = {DType::FLOAT32};
+    default_dtype = DType::FLOAT32;
 
     initialize();
 }
 
@@ -41,6 +44,28 @@ void CUDA::initialize() {
     std::printf("Using CUDA device %d: %s\n", device_id, deviceProp.name);
 }
 
+bool CUDA::supports_dtype(DType dtype) const {
+    return supported_dtypes.contains(dtype);
+}
+
+void CUDA::set_default_dtype(DType dtype) {
+    if (!supported_dtypes.contains(dtype)) {
+        throw std::runtime_error("Unsupported dtype");
+    }
+
+    default_dtype = dtype;
+}
+
+CUDANet::DType CUDA::get_default_dtype() const {
+    if (default_dtype) {
+        return default_dtype.value();
+    }
+
+    const_cast<CUDA*>(this)->default_dtype = DType::FLOAT32;
+    return DType::FLOAT32;
+}
+
+
 void* CUDA::allocate(size_t bytes) {
     void* d_ptr = nullptr;
     CUDA_CHECK(cudaMalloc(&d_ptr, bytes));
diff --git a/src/tensor.cpp b/src/tensor.cpp
index 5750d49..fc42057 100644
--- a/src/tensor.cpp
+++ b/src/tensor.cpp
@@ -1,20 +1,27 @@
-#include <stdexcept>
-
 #include "tensor.hpp"
 
+#include <stdexcept>
+
 using namespace CUDANet;
 
+Tensor::Tensor(Shape shape, CUDANet::Backend* backend)
+    : Tensor(shape, backend->get_default_dtype(), backend) {}
+
 Tensor::Tensor(Shape shape, DType dtype, Backend* backend)
     : shape(shape), dtype(dtype), backend(backend), d_ptr(nullptr) {
-
     if (shape.empty()) {
         throw std::runtime_error("Tensor shape cannot be empty");
     }
-    
+
+    // Check if backend supports DType
+    if (!backend->supports_dtype(dtype)) {
+        throw std::runtime_error("Unsupported DType");
+    }
+
     // Count total elements
     size_t count = 1;
-    for (const auto& dim : shape) {
-        count *= dim;
+    for (size_t i = 0; i < shape.size(); ++i) {
+        count *= shape[i];
     }
 
     total_elms = count;
@@ -39,9 +46,8 @@ Tensor::Tensor(Tensor&& other) noexcept
       total_elms(other.total_elms),
      total_size(other.total_size),
      backend(other.backend),
-      d_ptr(other.d_ptr)
-{
-    other.d_ptr = nullptr;
+      d_ptr(other.d_ptr) {
+    other.d_ptr   = nullptr;
     other.backend = nullptr;
 }
 
@@ -51,17 +57,17 @@ Tensor& Tensor::operator=(Tensor&& other) noexcept {
         if (d_ptr != nullptr && backend != nullptr) {
             backend->deallocate(d_ptr);
         }
-        
+
         // Steal other's resources
-        shape = std::move(other.shape);
-        dtype = other.dtype;
+        shape      = std::move(other.shape);
+        dtype      = other.dtype;
         total_elms = other.total_elms;
         total_size = other.total_size;
-        backend = other.backend;
-        d_ptr = other.d_ptr;
-        
+        backend    = other.backend;
+        d_ptr      = other.d_ptr;
+
         // Leave other in valid but empty state
-        other.d_ptr = nullptr;
+        other.d_ptr   = nullptr;
         other.backend = nullptr;
     }
     return *this;
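
Usage sketch (not part of the patch): a minimal example of how the new default-dtype API and the two-argument Tensor constructor are expected to be driven from calling code. The helper name, the way Shape and BackendConfig are obtained, and the namespace qualification of Shape are assumptions, not confirmed by this diff.

// Illustrative only: exercises the dtype API introduced in this patch.
#include "backend/cuda/cuda.cuh"
#include "tensor.hpp"

CUDANet::Tensor make_default_tensor(const CUDANet::Shape&    shape,
                                    CUDANet::Backends::CUDA& backend) {
    // The CUDA backend only registers FLOAT32 in this patch, and
    // set_default_dtype throws for any dtype the backend does not support.
    if (backend.supports_dtype(CUDANet::DType::FLOAT32)) {
        backend.set_default_dtype(CUDANet::DType::FLOAT32);
    }

    // The new two-argument constructor delegates to the dtype-taking one
    // via backend->get_default_dtype(), so no dtype is spelled out here.
    return CUDANet::Tensor(shape, &backend);
}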