Refactor size calculations in layers and backend

This commit is contained in:
2025-11-27 22:01:09 +01:00
parent c855ae89ec
commit e79667671a
13 changed files with 58 additions and 64 deletions

View File

@@ -28,6 +28,7 @@ class Backend {
std::optional<DType> default_dtype;
public:
// Dtypes
virtual bool supports_dtype(DType dtype) const = 0;
virtual void set_default_dtype(DType dtype) = 0;
virtual DType get_default_dtype() const = 0;

View File

@@ -15,10 +15,6 @@ class Module {
CUDANet::Shape output_shape();
size_t input_size();
size_t output_size();
void register_layer(const std::string& name, Layer* layer);
void register_module(Module& module);

View File

@@ -16,6 +16,19 @@ enum class DType
// INT32, // Not implemented yet
};
size_t dtype_size(DType dtype) {
switch (dtype)
{
case DType::FLOAT32:
return 4;
break;
default:
throw std::runtime_error("Unknown DType");
break;
}
}
class Tensor
{
public: