Add default dtype to backend

This commit is contained in:
2025-11-25 23:42:19 +01:00
parent ad079560ff
commit 84153ac49c
6 changed files with 73 additions and 16 deletions

View File

@@ -1,8 +1,10 @@
#pragma once
#include <memory>
#include <optional>
#include "shape.hpp"
#include "tensor.hpp"
namespace CUDANet {
@@ -22,7 +24,14 @@ class BackendFactory {
};
class Backend {
protected:
std::optional<DType> default_dtype;
public:
virtual bool supports_dtype(DType dtype) const = 0;
virtual void set_default_dtype(DType dtype) = 0;
virtual DType get_default_dtype() const = 0;
// Memory management
virtual void* allocate(size_t bytes) = 0;
virtual void deallocate(void* ptr) = 0;