Move dtype size implementation to cpp file

This commit is contained in:
2025-11-27 23:29:42 +01:00
parent 7e27c87673
commit 71dc5a924d
2 changed files with 15 additions and 12 deletions

View File

@@ -4,6 +4,19 @@
using namespace CUDANet;
size_t dtype_size(DType dtype) {
switch (dtype)
{
case DType::FLOAT32:
return 4;
break;
default:
throw std::runtime_error("Unknown DType");
break;
}
}
Tensor::Tensor(Shape shape, CUDANet::Backend* backend)
: Tensor(shape, backend->get_default_dtype(), backend) {}