mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-12-22 14:24:22 +00:00
Add dtype parameter to layer constructors
This commit is contained in:
@@ -16,6 +16,8 @@ namespace CUDANet {
|
||||
*
|
||||
*/
|
||||
class Layer {
|
||||
protected:
|
||||
CUDANet::DType dtype;
|
||||
public:
|
||||
|
||||
virtual ~Layer(){};
|
||||
@@ -39,4 +41,4 @@ class Layer {
|
||||
virtual size_t get_biases_size() = 0;
|
||||
};
|
||||
|
||||
} // namespace CUDANet::Layers
|
||||
} // namespace CUDANet
|
||||
|
||||
@@ -20,12 +20,13 @@ enum ActivationType { SIGMOID, RELU, SOFTMAX, NONE };
|
||||
* @brief Utility class that performs activation
|
||||
*
|
||||
*/
|
||||
class Activation : public Layer {
|
||||
class Activation : public CUDANet::Layer {
|
||||
public:
|
||||
|
||||
Activation() = default;
|
||||
|
||||
Activation(ActivationType activation, const CUDANet::Shape &shape, CUDANet::Backend* backend);
|
||||
Activation(ActivationType activation, const CUDANet::Shape &shape, CUDANet::DType dtype, CUDANet::Backend* backend);
|
||||
|
||||
~Activation() = default;
|
||||
|
||||
@@ -50,7 +51,7 @@ class Activation : public Layer {
|
||||
|
||||
private:
|
||||
CUDANet::Backend* backend;
|
||||
ActivationType activationType;
|
||||
ActivationType activation_type;
|
||||
CUDANet::Shape shape;
|
||||
|
||||
CUDANet::Tensor softmax_sum;
|
||||
|
||||
@@ -8,6 +8,7 @@ namespace CUDANet::Layers {
|
||||
class Add {
|
||||
public:
|
||||
Add(CUDANet::Shape a_shape, CUDANet::Shape b_shape, CUDANet::Backend* backend);
|
||||
Add(CUDANet::Shape a_shape, CUDANet::Shape b_shape, CUDANet::DType dtype, CUDANet::Backend* backend);
|
||||
|
||||
~Add();
|
||||
|
||||
@@ -19,6 +20,8 @@ class Add {
|
||||
CUDANet::Tensor output;
|
||||
|
||||
CUDANet::Backend *backend;
|
||||
|
||||
CUDANet::DType dtype;
|
||||
};
|
||||
|
||||
} // namespace CUDANet::Layers
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
namespace CUDANet::Layers {
|
||||
|
||||
class AvgPool2d : public Layer {
|
||||
class AvgPool2d : public CUDANet::Layer {
|
||||
public:
|
||||
AvgPool2d(
|
||||
CUDANet::Shape input_shape,
|
||||
@@ -13,6 +13,14 @@ class AvgPool2d : public Layer {
|
||||
CUDANet::Shape padding_shape,
|
||||
CUDANet::Backend *backend
|
||||
);
|
||||
AvgPool2d(
|
||||
CUDANet::Shape input_shape,
|
||||
CUDANet::Shape pool_shape,
|
||||
CUDANet::Shape stride_shape,
|
||||
CUDANet::Shape padding_shape,
|
||||
CUDANet::DType dtype,
|
||||
CUDANet::Backend *backend
|
||||
);
|
||||
|
||||
~AvgPool2d();
|
||||
|
||||
@@ -50,6 +58,7 @@ class AvgPool2d : public Layer {
|
||||
class AdaptiveAvgPool2d : public AvgPool2d {
|
||||
public:
|
||||
AdaptiveAvgPool2d(CUDANet::Shape input_shape, CUDANet::Shape output_shape, CUDANet::Backend *backend);
|
||||
AdaptiveAvgPool2d(CUDANet::Shape input_shape, CUDANet::Shape output_shape, CUDANet::DType dtype, CUDANet::Backend *backend);
|
||||
};
|
||||
|
||||
} // namespace CUDANet::Layers
|
||||
|
||||
@@ -4,9 +4,10 @@
|
||||
|
||||
namespace CUDANet::Layers {
|
||||
|
||||
class BatchNorm2d : public Layer {
|
||||
class BatchNorm2d : public CUDANet::Layer {
|
||||
public:
|
||||
BatchNorm2d(CUDANet::Shape input_shape, float epsilon, CUDANet::Backend *backend);
|
||||
BatchNorm2d(CUDANet::Shape input_shape, float epsilon, CUDANet::DType dtype, CUDANet::Backend *backend);
|
||||
|
||||
~BatchNorm2d();
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ class Concat {
|
||||
public:
|
||||
|
||||
Concat(const CUDANet::Shape a_shape, const CUDANet::Shape b_shape, CUDANet::Backend *backend);
|
||||
Concat(const CUDANet::Shape a_shape, const CUDANet::Shape b_shape, CUDANet::DType dtype, CUDANet::Backend *backend);
|
||||
|
||||
~Concat();
|
||||
|
||||
@@ -27,6 +28,8 @@ class Concat {
|
||||
CUDANet::Tensor output;
|
||||
|
||||
CUDANet::Backend *backend;
|
||||
|
||||
CUDANet::DType dtype;
|
||||
};
|
||||
|
||||
} // namespace CUDANet::Layers
|
||||
|
||||
@@ -8,7 +8,7 @@ namespace CUDANet::Layers {
|
||||
* @brief 2D convolutional layer
|
||||
*
|
||||
*/
|
||||
class Conv2d : public Layer {
|
||||
class Conv2d : public CUDANet::Layer {
|
||||
public:
|
||||
Conv2d(
|
||||
CUDANet::Shape input_shape,
|
||||
@@ -17,6 +17,14 @@ class Conv2d : public Layer {
|
||||
CUDANet::Shape padding_shape,
|
||||
CUDANet::Backend* backend
|
||||
);
|
||||
Conv2d(
|
||||
CUDANet::Shape input_shape,
|
||||
CUDANet::Shape kernel_shape,
|
||||
CUDANet::Shape stride_shape,
|
||||
CUDANet::Shape padding_shape,
|
||||
CUDANet::DType dtype,
|
||||
CUDANet::Backend* backend
|
||||
);
|
||||
|
||||
~Conv2d();
|
||||
|
||||
|
||||
@@ -9,10 +9,11 @@ namespace CUDANet::Layers {
|
||||
* @brief Dense (fully connected) layer
|
||||
*
|
||||
*/
|
||||
class Dense : public Layer {
|
||||
class Dense : public CUDANet::Layer {
|
||||
public:
|
||||
|
||||
Dense(CUDANet::Shape input_shape, CUDANet::Shape output_shape, CUDANet::Backend *backend);
|
||||
Dense(CUDANet::Shape input_shape, CUDANet::Shape output_shape, CUDANet::DType dtype, CUDANet::Backend *backend);
|
||||
|
||||
~Dense();
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
namespace CUDANet::Layers {
|
||||
|
||||
class MaxPool2d : public Layer {
|
||||
class MaxPool2d : public CUDANet::Layer {
|
||||
public:
|
||||
MaxPool2d(
|
||||
CUDANet::Shape input_shape,
|
||||
@@ -13,6 +13,14 @@ class MaxPool2d : public Layer {
|
||||
CUDANet::Shape padding_shape,
|
||||
CUDANet::Backend* backend
|
||||
);
|
||||
MaxPool2d(
|
||||
CUDANet::Shape input_shape,
|
||||
CUDANet::Shape pool_shape,
|
||||
CUDANet::Shape stride_shape,
|
||||
CUDANet::Shape padding_shape,
|
||||
CUDANet::DType dtype,
|
||||
CUDANet::Backend* backend
|
||||
);
|
||||
~MaxPool2d();
|
||||
|
||||
CUDANet::Tensor& forward(CUDANet::Tensor &input) override;
|
||||
|
||||
Reference in New Issue
Block a user