mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-12-22 22:34:22 +00:00
Add default dtype to backend
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
|
||||
#include "shape.hpp"
|
||||
#include "tensor.hpp"
|
||||
|
||||
namespace CUDANet {
|
||||
|
||||
@@ -22,7 +24,14 @@ class BackendFactory {
|
||||
};
|
||||
|
||||
class Backend {
|
||||
protected:
|
||||
std::optional<DType> default_dtype;
|
||||
public:
|
||||
|
||||
virtual bool supports_dtype(DType dtype) const = 0;
|
||||
virtual void set_default_dtype(DType dtype) = 0;
|
||||
virtual DType get_default_dtype() const = 0;
|
||||
|
||||
// Memory management
|
||||
virtual void* allocate(size_t bytes) = 0;
|
||||
virtual void deallocate(void* ptr) = 0;
|
||||
|
||||
Reference in New Issue
Block a user