diff --git a/include/tensor.hpp b/include/tensor.hpp index 3b329c0..d72c6db 100644 --- a/include/tensor.hpp +++ b/include/tensor.hpp @@ -16,19 +16,9 @@ enum class DType // INT32, // Not implemented yet }; -size_t dtype_size(DType dtype) { - switch (dtype) - { - case DType::FLOAT32: - return 4; - break; - - default: - throw std::runtime_error("Unknown DType"); - break; - } -} +size_t dtype_size(DType dtype); +// Forward declaration class Backend; class Tensor diff --git a/src/tensor.cpp b/src/tensor.cpp index fbfd2b4..8412370 100644 --- a/src/tensor.cpp +++ b/src/tensor.cpp @@ -4,6 +4,19 @@ using namespace CUDANet; +size_t dtype_size(DType dtype) { + switch (dtype) + { + case DType::FLOAT32: + return 4; + break; + + default: + throw std::runtime_error("Unknown DType"); + break; + } +} + Tensor::Tensor(Shape shape, CUDANet::Backend* backend) : Tensor(shape, backend->get_default_dtype(), backend) {}