Add dtype parameter to layer constructors

This commit is contained in:
2025-11-26 00:19:33 +01:00
parent 84153ac49c
commit 13d3d38b68
17 changed files with 169 additions and 49 deletions

View File

@@ -4,7 +4,7 @@
namespace CUDANet::Layers {
class MaxPool2d : public Layer {
class MaxPool2d : public CUDANet::Layer {
public:
MaxPool2d(
CUDANet::Shape input_shape,
@@ -13,6 +13,14 @@ class MaxPool2d : public Layer {
CUDANet::Shape padding_shape,
CUDANet::Backend* backend
);
MaxPool2d(
CUDANet::Shape input_shape,
CUDANet::Shape pool_shape,
CUDANet::Shape stride_shape,
CUDANet::Shape padding_shape,
CUDANet::DType dtype,
CUDANet::Backend* backend
);
~MaxPool2d();
CUDANet::Tensor& forward(CUDANet::Tensor &input) override;