Add dtype parameter to layer constructors

This commit is contained in:
2025-11-26 00:19:33 +01:00
parent 84153ac49c
commit 13d3d38b68
17 changed files with 169 additions and 49 deletions

View File

@@ -3,7 +3,11 @@
using namespace CUDANet::Layers;
Add::Add(CUDANet::Shape a_shape, CUDANet::Shape b_shape, CUDANet::Backend* backend) : backend(backend) {
Add::Add(CUDANet::Shape a_shape, CUDANet::Shape b_shape, CUDANet::Backend* backend)
: Add(a_shape, b_shape, backend->get_default_dtype(), backend) {}
Add::Add(CUDANet::Shape a_shape, CUDANet::Shape b_shape, CUDANet::DType dtype, CUDANet::Backend* backend)
: backend(backend), dtype(dtype) {
if (a_shape != b_shape) {
throw InvalidShapeException(
"Add requires matching dimensions", a_shape, b_shape
@@ -11,7 +15,7 @@ Add::Add(CUDANet::Shape a_shape, CUDANet::Shape b_shape, CUDANet::Backend* backe
}
out_shape = a_shape;
output = CUDANet::Tensor(out_shape, CUDANet::DType::FLOAT32, backend);
output = CUDANet::Tensor(out_shape, dtype, backend);
}
Add::~Add() {}