Add dtype parameter to layer constructors

This commit is contained in:
2025-11-26 00:19:33 +01:00
parent 84153ac49c
commit 13d3d38b68
17 changed files with 169 additions and 49 deletions

View File

@@ -4,9 +4,10 @@
namespace CUDANet::Layers {
class BatchNorm2d : public Layer {
class BatchNorm2d : public CUDANet::Layer {
public:
BatchNorm2d(CUDANet::Shape input_shape, float epsilon, CUDANet::Backend *backend);
BatchNorm2d(CUDANet::Shape input_shape, float epsilon, CUDANet::DType dtype, CUDANet::Backend *backend);
~BatchNorm2d();