mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-12-23 23:04:25 +00:00
Add dtype parameter to layer constructors
This commit is contained in:
@@ -8,6 +8,7 @@ namespace CUDANet::Layers {
|
||||
class Add {
|
||||
public:
|
||||
Add(CUDANet::Shape a_shape, CUDANet::Shape b_shape, CUDANet::Backend* backend);
|
||||
Add(CUDANet::Shape a_shape, CUDANet::Shape b_shape, CUDANet::DType dtype, CUDANet::Backend* backend);
|
||||
|
||||
~Add();
|
||||
|
||||
@@ -19,6 +20,8 @@ class Add {
|
||||
CUDANet::Tensor output;
|
||||
|
||||
CUDANet::Backend *backend;
|
||||
|
||||
CUDANet::DType dtype;
|
||||
};
|
||||
|
||||
} // namespace CUDANet::Layers
|
||||
|
||||
Reference in New Issue
Block a user