Add dtype parameter to layer constructors

This commit is contained in:
2025-11-26 00:19:33 +01:00
parent 84153ac49c
commit 13d3d38b68
17 changed files with 169 additions and 49 deletions

View File

@@ -4,7 +4,7 @@
namespace CUDANet::Layers {
class AvgPool2d : public Layer {
class AvgPool2d : public CUDANet::Layer {
public:
AvgPool2d(
CUDANet::Shape input_shape,
@@ -13,6 +13,14 @@ class AvgPool2d : public Layer {
CUDANet::Shape padding_shape,
CUDANet::Backend *backend
);
AvgPool2d(
CUDANet::Shape input_shape,
CUDANet::Shape pool_shape,
CUDANet::Shape stride_shape,
CUDANet::Shape padding_shape,
CUDANet::DType dtype,
CUDANet::Backend *backend
);
~AvgPool2d();
@@ -50,6 +58,7 @@ class AvgPool2d : public Layer {
class AdaptiveAvgPool2d : public AvgPool2d {
public:
AdaptiveAvgPool2d(CUDANet::Shape input_shape, CUDANet::Shape output_shape, CUDANet::Backend *backend);
AdaptiveAvgPool2d(CUDANet::Shape input_shape, CUDANet::Shape output_shape, CUDANet::DType dtype, CUDANet::Backend *backend);
};
} // namespace CUDANet::Layers