Add dtype parameter to layer constructors

This commit is contained in:
2025-11-26 00:19:33 +01:00
parent 84153ac49c
commit 13d3d38b68
17 changed files with 169 additions and 49 deletions

View File

@@ -16,6 +16,8 @@ namespace CUDANet {
*
*/
class Layer {
protected:
CUDANet::DType dtype;
public:
virtual ~Layer(){};
@@ -39,4 +41,4 @@ class Layer {
virtual size_t get_biases_size() = 0;
};
} // namespace CUDANet::Layers
} // namespace CUDANet

View File

@@ -20,12 +20,13 @@ enum ActivationType { SIGMOID, RELU, SOFTMAX, NONE };
* @brief Utility class that performs activation
*
*/
class Activation : public Layer {
class Activation : public CUDANet::Layer {
public:
Activation() = default;
Activation(ActivationType activation, const CUDANet::Shape &shape, CUDANet::Backend* backend);
Activation(ActivationType activation, const CUDANet::Shape &shape, CUDANet::DType dtype, CUDANet::Backend* backend);
~Activation() = default;
@@ -50,7 +51,7 @@ class Activation : public Layer {
private:
CUDANet::Backend* backend;
ActivationType activationType;
ActivationType activation_type;
CUDANet::Shape shape;
CUDANet::Tensor softmax_sum;

View File

@@ -8,6 +8,7 @@ namespace CUDANet::Layers {
class Add {
public:
Add(CUDANet::Shape a_shape, CUDANet::Shape b_shape, CUDANet::Backend* backend);
Add(CUDANet::Shape a_shape, CUDANet::Shape b_shape, CUDANet::DType dtype, CUDANet::Backend* backend);
~Add();
@@ -19,6 +20,8 @@ class Add {
CUDANet::Tensor output;
CUDANet::Backend *backend;
CUDANet::DType dtype;
};
} // namespace CUDANet::Layers

View File

@@ -4,7 +4,7 @@
namespace CUDANet::Layers {
class AvgPool2d : public Layer {
class AvgPool2d : public CUDANet::Layer {
public:
AvgPool2d(
CUDANet::Shape input_shape,
@@ -13,6 +13,14 @@ class AvgPool2d : public Layer {
CUDANet::Shape padding_shape,
CUDANet::Backend *backend
);
AvgPool2d(
CUDANet::Shape input_shape,
CUDANet::Shape pool_shape,
CUDANet::Shape stride_shape,
CUDANet::Shape padding_shape,
CUDANet::DType dtype,
CUDANet::Backend *backend
);
~AvgPool2d();
@@ -50,6 +58,7 @@ class AvgPool2d : public Layer {
class AdaptiveAvgPool2d : public AvgPool2d {
public:
AdaptiveAvgPool2d(CUDANet::Shape input_shape, CUDANet::Shape output_shape, CUDANet::Backend *backend);
AdaptiveAvgPool2d(CUDANet::Shape input_shape, CUDANet::Shape output_shape, CUDANet::DType dtype, CUDANet::Backend *backend);
};
} // namespace CUDANet::Layers

View File

@@ -4,9 +4,10 @@
namespace CUDANet::Layers {
class BatchNorm2d : public Layer {
class BatchNorm2d : public CUDANet::Layer {
public:
BatchNorm2d(CUDANet::Shape input_shape, float epsilon, CUDANet::Backend *backend);
BatchNorm2d(CUDANet::Shape input_shape, float epsilon, CUDANet::DType dtype, CUDANet::Backend *backend);
~BatchNorm2d();

View File

@@ -12,6 +12,7 @@ class Concat {
public:
Concat(const CUDANet::Shape a_shape, const CUDANet::Shape b_shape, CUDANet::Backend *backend);
Concat(const CUDANet::Shape a_shape, const CUDANet::Shape b_shape, CUDANet::DType dtype, CUDANet::Backend *backend);
~Concat();
@@ -27,6 +28,8 @@ class Concat {
CUDANet::Tensor output;
CUDANet::Backend *backend;
CUDANet::DType dtype;
};
} // namespace CUDANet::Layers

View File

@@ -8,7 +8,7 @@ namespace CUDANet::Layers {
* @brief 2D convolutional layer
*
*/
class Conv2d : public Layer {
class Conv2d : public CUDANet::Layer {
public:
Conv2d(
CUDANet::Shape input_shape,
@@ -17,6 +17,14 @@ class Conv2d : public Layer {
CUDANet::Shape padding_shape,
CUDANet::Backend* backend
);
Conv2d(
CUDANet::Shape input_shape,
CUDANet::Shape kernel_shape,
CUDANet::Shape stride_shape,
CUDANet::Shape padding_shape,
CUDANet::DType dtype,
CUDANet::Backend* backend
);
~Conv2d();

View File

@@ -9,10 +9,11 @@ namespace CUDANet::Layers {
* @brief Dense (fully connected) layer
*
*/
class Dense : public Layer {
class Dense : public CUDANet::Layer {
public:
Dense(CUDANet::Shape input_shape, CUDANet::Shape output_shape, CUDANet::Backend *backend);
Dense(CUDANet::Shape input_shape, CUDANet::Shape output_shape, CUDANet::DType dtype, CUDANet::Backend *backend);
~Dense();

View File

@@ -4,7 +4,7 @@
namespace CUDANet::Layers {
class MaxPool2d : public Layer {
class MaxPool2d : public CUDANet::Layer {
public:
MaxPool2d(
CUDANet::Shape input_shape,
@@ -13,6 +13,14 @@ class MaxPool2d : public Layer {
CUDANet::Shape padding_shape,
CUDANet::Backend* backend
);
MaxPool2d(
CUDANet::Shape input_shape,
CUDANet::Shape pool_shape,
CUDANet::Shape stride_shape,
CUDANet::Shape padding_shape,
CUDANet::DType dtype,
CUDANet::Backend* backend
);
~MaxPool2d();
CUDANet::Tensor& forward(CUDANet::Tensor &input) override;