Add default dtype to backend

This commit is contained in:
2025-11-25 23:42:19 +01:00
parent ad079560ff
commit 84153ac49c
6 changed files with 73 additions and 16 deletions

View File

@@ -16,11 +16,14 @@ enum class DType
// INT32, // Not implemented yet
};
size_t dtype_size(DType dtype);
class Tensor
{
public:
Tensor() = default;
Tensor(Shape shape, CUDANet::Backend* backend);
Tensor(Shape shape, DType dtype, CUDANet::Backend* backend);
Tensor(Tensor&& other) noexcept;
@@ -30,6 +33,8 @@ public:
~Tensor();
DType get_dtype();
size_t size() const;
size_t numel() const;