Move dtype size implementation to cpp file

This commit is contained in:
2025-11-27 23:29:42 +01:00
parent 7e27c87673
commit 71dc5a924d
2 changed files with 15 additions and 12 deletions

View File

@@ -16,19 +16,9 @@ enum class DType
// INT32, // Not implemented yet
};
size_t dtype_size(DType dtype) {
switch (dtype)
{
case DType::FLOAT32:
return 4;
break;
default:
throw std::runtime_error("Unknown DType");
break;
}
}
size_t dtype_size(DType dtype);
// Forward declaration
class Backend;
class Tensor

View File

@@ -4,6 +4,19 @@
using namespace CUDANet;
size_t dtype_size(DType dtype) {
switch (dtype)
{
case DType::FLOAT32:
return 4;
break;
default:
throw std::runtime_error("Unknown DType");
break;
}
}
Tensor::Tensor(Shape shape, CUDANet::Backend* backend)
: Tensor(shape, backend->get_default_dtype(), backend) {}