Move factory implementation out of header

This commit is contained in:
2025-11-24 22:01:54 +01:00
parent a40ba96d4f
commit 60964cf294
2 changed files with 43 additions and 32 deletions

View File

@@ -1,17 +1,14 @@
#pragma once #pragma once
#include <cstddef> #include <memory>
#include <stdexcept>
#include "shape.hpp" #include "shape.hpp"
#ifdef USE_CUDA
#include "backend/cuda/cuda.cuh"
#endif
namespace CUDANet { namespace CUDANet {
// Forward declaration // Forward declaration
class Tensor; class Tensor;
class Backend;
enum BackendType { CUDA_BACKEND, CPU_BACKEND }; enum BackendType { CUDA_BACKEND, CPU_BACKEND };
@@ -21,33 +18,7 @@ struct BackendConfig {
class BackendFactory { class BackendFactory {
public: public:
static std::unique_ptr<Backend> create(BackendType backend_type, const BackendConfig& config) { static std::unique_ptr<Backend> create(BackendType backend_type, const BackendConfig& config);
switch (backend_type)
{
case BackendType::CUDA_BACKEND:
#ifdef USE_CUDA
if (!CUDANet::Backends::CUDA::is_cuda_available()) {
throw std::runtime_error("No CUDA devices found")
}
auto cuda = std::make_unique<CUDANet::Backends::CUDA>(config);
cuda.initialize();
return cuda;
#else
throw std::runtime_error("Library was compiled without CUDA support.");
#endif
break;
default:
break;
}
return nullptr;
}
}; };
class Backend { class Backend {

40
src/backend_factory.cpp Normal file
View File

@@ -0,0 +1,40 @@
#include <stdexcept>
#include <memory>
#ifdef USE_CUDA
#include "backend/cuda/cuda.cuh"
#endif
#include "backend.hpp"
namespace CUDANet {
std::unique_ptr<Backend> BackendFactory::create(BackendType backend_type, const BackendConfig& config) {
switch (backend_type)
{
case BackendType::CUDA_BACKEND:
#ifdef USE_CUDA
if (!CUDANet::Backends::CUDA::is_cuda_available()) {
throw std::runtime_error("No CUDA devices found")
}
auto cuda = std::make_unique<CUDANet::Backends::CUDA>(config);
cuda.initialize();
return cuda;
#else
throw std::runtime_error("Library was compiled without CUDA support.");
#endif
break;
default:
break;
}
return nullptr;
}
} // namespace CUDANet