Add dtype parameter to layer constructors

This commit is contained in:
2025-11-26 00:19:33 +01:00
parent 84153ac49c
commit 13d3d38b68
17 changed files with 169 additions and 49 deletions

View File

@@ -12,6 +12,7 @@ class Concat {
public:
Concat(const CUDANet::Shape a_shape, const CUDANet::Shape b_shape, CUDANet::Backend *backend);
Concat(const CUDANet::Shape a_shape, const CUDANet::Shape b_shape, CUDANet::DType dtype, CUDANet::Backend *backend);
~Concat();
@@ -27,6 +28,8 @@ class Concat {
CUDANet::Tensor output;
CUDANet::Backend *backend;
CUDANet::DType dtype;
};
} // namespace CUDANet::Layers