Add dtype parameter to layer constructors

This commit is contained in:
2025-11-26 00:19:33 +01:00
parent 84153ac49c
commit 13d3d38b68
17 changed files with 169 additions and 49 deletions

View File

@@ -8,7 +8,7 @@ namespace CUDANet::Layers {
* @brief 2D convolutional layer
*
*/
class Conv2d : public Layer {
class Conv2d : public CUDANet::Layer {
public:
Conv2d(
CUDANet::Shape input_shape,
@@ -17,6 +17,14 @@ class Conv2d : public Layer {
CUDANet::Shape padding_shape,
CUDANet::Backend* backend
);
Conv2d(
CUDANet::Shape input_shape,
CUDANet::Shape kernel_shape,
CUDANet::Shape stride_shape,
CUDANet::Shape padding_shape,
CUDANet::DType dtype,
CUDANet::Backend* backend
);
~Conv2d();