Add dtype parameter to layer constructors

This commit is contained in:
2025-11-26 00:19:33 +01:00
parent 84153ac49c
commit 13d3d38b68
17 changed files with 169 additions and 49 deletions

View File

@@ -8,6 +8,7 @@ namespace CUDANet::Layers {
class Add {
public:
Add(CUDANet::Shape a_shape, CUDANet::Shape b_shape, CUDANet::Backend* backend);
Add(CUDANet::Shape a_shape, CUDANet::Shape b_shape, CUDANet::DType dtype, CUDANet::Backend* backend);
~Add();
@@ -19,6 +20,8 @@ class Add {
CUDANet::Tensor output;
CUDANet::Backend *backend;
CUDANet::DType dtype;
};
} // namespace CUDANet::Layers