Add default dtype to backend

This commit is contained in:
2025-11-25 23:42:19 +01:00
parent ad079560ff
commit 84153ac49c
6 changed files with 73 additions and 16 deletions

View File

@@ -1,6 +1,7 @@
#pragma once
#include <cstdio>
#include <set>
#include "backend.hpp"
#include "tensor.hpp"
@@ -29,9 +30,14 @@ namespace CUDANet::Backends {
class CUDA : public Backend {
private:
int device_id;
std::set<DType> supported_dtypes;
public:
CUDA(const BackendConfig& config);
bool supports_dtype(DType dtype) const override;
void set_default_dtype(DType dtype) override;
DType get_default_dtype() const override;
static bool is_cuda_available();
void initialize();