mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-12-23 06:44:24 +00:00
Add default dtype to backend
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdio>
|
||||
#include <set>
|
||||
|
||||
#include "backend.hpp"
|
||||
#include "tensor.hpp"
|
||||
@@ -29,9 +30,14 @@ namespace CUDANet::Backends {
|
||||
class CUDA : public Backend {
|
||||
private:
|
||||
int device_id;
|
||||
std::set<DType> supported_dtypes;
|
||||
public:
|
||||
CUDA(const BackendConfig& config);
|
||||
|
||||
bool supports_dtype(DType dtype) const override;
|
||||
void set_default_dtype(DType dtype) override;
|
||||
DType get_default_dtype() const override;
|
||||
|
||||
static bool is_cuda_available();
|
||||
void initialize();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user