From a40ba96d4fbef3bc8ff6b871cd1d7a95e4f3fb9d Mon Sep 17 00:00:00 2001 From: LordMathis Date: Mon, 24 Nov 2025 21:53:47 +0100 Subject: [PATCH] Implement backend factory --- include/backend.hpp | 41 +++++++++++++++ .../cuda/{cuda_backend.cuh => all.cuh} | 0 include/backend/cuda/cuda.cuh | 7 +++ include/{cudanet.cuh => cudanet.hpp} | 6 +-- include/{utils => datasets}/imagenet.hpp | 0 include/shape.hpp | 1 + src/backends/cuda/cuda.cu | 52 +++++++++++++++++++ src/backends/cuda/cuda_backend.cu | 38 -------------- 8 files changed, 104 insertions(+), 41 deletions(-) rename include/backend/cuda/{cuda_backend.cuh => all.cuh} (100%) rename include/{cudanet.cuh => cudanet.hpp} (94%) rename include/{utils => datasets}/imagenet.hpp (100%) create mode 100644 src/backends/cuda/cuda.cu delete mode 100644 src/backends/cuda/cuda_backend.cu diff --git a/include/backend.hpp b/include/backend.hpp index 07359e2..9242dbe 100644 --- a/include/backend.hpp +++ b/include/backend.hpp @@ -1,14 +1,55 @@ #pragma once #include +#include #include "shape.hpp" +#ifdef USE_CUDA +#include "backend/cuda/cuda.cuh" +#endif namespace CUDANet { // Forward declaration class Tensor; +enum BackendType { CUDA_BACKEND, CPU_BACKEND }; + +struct BackendConfig { + int device_id = 0; +}; + +class BackendFactory { + public: + static std::unique_ptr create(BackendType backend_type, const BackendConfig& config) { + switch (backend_type) + { + case BackendType::CUDA_BACKEND: + #ifdef USE_CUDA + + if (!CUDANet::Backends::CUDA::is_cuda_available()) { + throw std::runtime_error("No CUDA devices found") + } + + auto cuda = std::make_unique(config); + cuda.initialize(); + + return cuda; + + #else + throw std::runtime_error("Library was compiled without CUDA support."); + #endif + + break; + + default: + break; + } + + return nullptr; + } +}; + class Backend { public: // Memory management diff --git a/include/backend/cuda/cuda_backend.cuh b/include/backend/cuda/all.cuh similarity index 100% rename from include/backend/cuda/cuda_backend.cuh rename to include/backend/cuda/all.cuh diff --git a/include/backend/cuda/cuda.cuh b/include/backend/cuda/cuda.cuh index e60f374..10d2979 100644 --- a/include/backend/cuda/cuda.cuh +++ b/include/backend/cuda/cuda.cuh @@ -27,7 +27,14 @@ do { \ namespace CUDANet::Backends { class CUDA : public Backend { + private: + int device_id; public: + CUDA(const BackendConfig& config); + + static bool is_cuda_available(); + void initialize(); + // Memory management void* allocate(size_t bytes) override; void deallocate(void* ptr) override; diff --git a/include/cudanet.cuh b/include/cudanet.hpp similarity index 94% rename from include/cudanet.cuh rename to include/cudanet.hpp index 96067f5..5bcd4d6 100644 --- a/include/cudanet.cuh +++ b/include/cudanet.hpp @@ -41,15 +41,15 @@ #include "layers/concat.hpp" // ============================================================================ -// Utilities +// Dataset Labels // ============================================================================ -#include "utils/imagenet.hpp" +#include "datasets/imagenet.hpp" // ============================================================================ // Backend-Specific Includes (conditionally compiled) // ============================================================================ #ifdef USE_CUDA -#include "backend/cuda/cuda_backend.cuh" +#include "backend/cuda/all.cuh" #endif diff --git a/include/utils/imagenet.hpp b/include/datasets/imagenet.hpp similarity index 100% rename from include/utils/imagenet.hpp rename to include/datasets/imagenet.hpp diff --git a/include/shape.hpp b/include/shape.hpp index e9db6ee..f54a1cf 100644 --- a/include/shape.hpp +++ b/include/shape.hpp @@ -9,6 +9,7 @@ #endif #include +#include #include namespace CUDANet { diff --git a/src/backends/cuda/cuda.cu b/src/backends/cuda/cuda.cu new file mode 100644 index 0000000..8d51d24 --- /dev/null +++ b/src/backends/cuda/cuda.cu @@ -0,0 +1,52 @@ +#include + +#include +#include +#include + +#include "backend/cuda/cuda.cuh" + +using namespace CUDANet::Backends; + + +CUDA::CUDA(const BackendConfig& config) { + device_id = config.device_id < 0 ? 0 : config.device_id; + initialize(); +} + +bool CUDA::is_cuda_available() { + int device_count; + cudaError_t result = cudaGetDeviceCount(&device_count); + + // Return false instead of crashing + if (result != cudaSuccess || device_count == 0) { + return false; + } + return true; +} + +void CUDA::initialize() { + + int device_count; + CUDA_CHECK(cudaGetDeviceCount(&device_count)); + if (device_id >= device_count) { + throw std::runtime_error(std::format("Invalid device id {}, only {} devices available", device_id, device_count)); + } + + CUDA_CHECK(cudaSetDevice(device_id)); + + cudaDeviceProp deviceProp; + CUDA_CHECK(cudaGetDeviceProperties(&deviceProp, device_id)); + + std::printf("Using CUDA device %d: %s\n", device_id, deviceProp.name); +} + +void* CUDA::allocate(size_t bytes) { + void* d_ptr = nullptr; + CUDA_CHECK(cudaMalloc(&d_ptr, bytes)); + return d_ptr; +} + +void CUDA::deallocate(void* ptr) { + CUDA_CHECK(cudaFree(ptr)); +} diff --git a/src/backends/cuda/cuda_backend.cu b/src/backends/cuda/cuda_backend.cu deleted file mode 100644 index 805a163..0000000 --- a/src/backends/cuda/cuda_backend.cu +++ /dev/null @@ -1,38 +0,0 @@ -#include - -#include -#include - -#include "backend/cuda/cuda.cuh" - -cudaDeviceProp initializeCUDA() { - int deviceCount; - CUDA_CHECK(cudaGetDeviceCount(&deviceCount)); - - if (deviceCount == 0) { - std::fprintf(stderr, "No CUDA devices found. Exiting.\n"); - std::exit(EXIT_FAILURE); - } - - int device = 0; - CUDA_CHECK(cudaSetDevice(device)); - - cudaDeviceProp deviceProp; - CUDA_CHECK(cudaGetDeviceProperties(&deviceProp, device)); - - std::printf("Using CUDA device %d: %s\n", device, deviceProp.name); - - return deviceProp; -} - -using namespace CUDANet::Backends; - -void* CUDA::allocate(size_t bytes) { - void* d_ptr = nullptr; - CUDA_CHECK(cudaMalloc(&d_ptr, bytes)); - return d_ptr; -} - -void CUDA::deallocate(void* ptr) { - CUDA_CHECK(cudaFree(ptr)); -}