Add dtype parameter to layer constructors

This commit is contained in:
2025-11-26 00:19:33 +01:00
parent 84153ac49c
commit 13d3d38b68
17 changed files with 169 additions and 49 deletions

View File

@@ -10,6 +10,16 @@ MaxPool2d::MaxPool2d(
CUDANet::Shape stride_shape,
CUDANet::Shape padding_shape,
CUDANet::Backend* backend
)
: MaxPool2d(input_shape, pool_shape, stride_shape, padding_shape, backend->get_default_dtype(), backend) {}
MaxPool2d::MaxPool2d(
CUDANet::Shape input_shape,
CUDANet::Shape pool_shape,
CUDANet::Shape stride_shape,
CUDANet::Shape padding_shape,
CUDANet::DType dtype,
CUDANet::Backend* backend
)
: in_shape(input_shape),
pool_shape(pool_shape),
@@ -32,6 +42,8 @@ MaxPool2d::MaxPool2d(
throw InvalidShapeException("padding", 2, padding_shape.size());
}
this->dtype = dtype;
out_shape = {
(in_shape[0] + 2 * padding_shape[0] - pool_shape[0]) / stride_shape[0] +
1,
@@ -42,7 +54,7 @@ MaxPool2d::MaxPool2d(
output = CUDANet::Tensor(
Shape{out_shape[0] * out_shape[1] * out_shape[2]},
CUDANet::DType::FLOAT32, backend
dtype, backend
);
}