diff --git a/include/backend.hpp b/include/backend.hpp index 9242dbe..6a41d61 100644 --- a/include/backend.hpp +++ b/include/backend.hpp @@ -1,17 +1,14 @@ #pragma once -#include -#include +#include #include "shape.hpp" -#ifdef USE_CUDA -#include "backend/cuda/cuda.cuh" -#endif namespace CUDANet { // Forward declaration class Tensor; +class Backend; enum BackendType { CUDA_BACKEND, CPU_BACKEND }; @@ -21,33 +18,7 @@ struct BackendConfig { class BackendFactory { public: - static std::unique_ptr create(BackendType backend_type, const BackendConfig& config) { - switch (backend_type) - { - case BackendType::CUDA_BACKEND: - #ifdef USE_CUDA - - if (!CUDANet::Backends::CUDA::is_cuda_available()) { - throw std::runtime_error("No CUDA devices found") - } - - auto cuda = std::make_unique(config); - cuda.initialize(); - - return cuda; - - #else - throw std::runtime_error("Library was compiled without CUDA support."); - #endif - - break; - - default: - break; - } - - return nullptr; - } + static std::unique_ptr create(BackendType backend_type, const BackendConfig& config); }; class Backend { diff --git a/src/backend_factory.cpp b/src/backend_factory.cpp new file mode 100644 index 0000000..beaea89 --- /dev/null +++ b/src/backend_factory.cpp @@ -0,0 +1,40 @@ +#include +#include + +#ifdef USE_CUDA +#include "backend/cuda/cuda.cuh" +#endif + +#include "backend.hpp" + +namespace CUDANet { + +std::unique_ptr BackendFactory::create(BackendType backend_type, const BackendConfig& config) { + switch (backend_type) + { + case BackendType::CUDA_BACKEND: + #ifdef USE_CUDA + + if (!CUDANet::Backends::CUDA::is_cuda_available()) { + throw std::runtime_error("No CUDA devices found") + } + + auto cuda = std::make_unique(config); + cuda.initialize(); + + return cuda; + + #else + throw std::runtime_error("Library was compiled without CUDA support."); + #endif + + break; + + default: + break; + } + + return nullptr; +} + +} // namespace CUDANet \ No newline at end of file