mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-12-22 14:24:22 +00:00
Refactor size calculations in layers and backend
This commit is contained in:
@@ -28,6 +28,7 @@ class Backend {
|
||||
std::optional<DType> default_dtype;
|
||||
public:
|
||||
|
||||
// Dtypes
|
||||
virtual bool supports_dtype(DType dtype) const = 0;
|
||||
virtual void set_default_dtype(DType dtype) = 0;
|
||||
virtual DType get_default_dtype() const = 0;
|
||||
|
||||
@@ -15,10 +15,6 @@ class Module {
|
||||
|
||||
CUDANet::Shape output_shape();
|
||||
|
||||
size_t input_size();
|
||||
|
||||
size_t output_size();
|
||||
|
||||
void register_layer(const std::string& name, Layer* layer);
|
||||
|
||||
void register_module(Module& module);
|
||||
|
||||
@@ -16,6 +16,19 @@ enum class DType
|
||||
// INT32, // Not implemented yet
|
||||
};
|
||||
|
||||
size_t dtype_size(DType dtype) {
|
||||
switch (dtype)
|
||||
{
|
||||
case DType::FLOAT32:
|
||||
return 4;
|
||||
break;
|
||||
|
||||
default:
|
||||
throw std::runtime_error("Unknown DType");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
class Tensor
|
||||
{
|
||||
public:
|
||||
|
||||
Reference in New Issue
Block a user