mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-12-22 22:34:22 +00:00
Add dtype parameter to layer constructors
This commit is contained in:
@@ -3,7 +3,11 @@
|
||||
using namespace CUDANet::Layers;
|
||||
|
||||
|
||||
Add::Add(CUDANet::Shape a_shape, CUDANet::Shape b_shape, CUDANet::Backend* backend) : backend(backend) {
|
||||
Add::Add(CUDANet::Shape a_shape, CUDANet::Shape b_shape, CUDANet::Backend* backend)
|
||||
: Add(a_shape, b_shape, backend->get_default_dtype(), backend) {}
|
||||
|
||||
Add::Add(CUDANet::Shape a_shape, CUDANet::Shape b_shape, CUDANet::DType dtype, CUDANet::Backend* backend)
|
||||
: backend(backend), dtype(dtype) {
|
||||
if (a_shape != b_shape) {
|
||||
throw InvalidShapeException(
|
||||
"Add requires matching dimensions", a_shape, b_shape
|
||||
@@ -11,7 +15,7 @@ Add::Add(CUDANet::Shape a_shape, CUDANet::Shape b_shape, CUDANet::Backend* backe
|
||||
}
|
||||
|
||||
out_shape = a_shape;
|
||||
output = CUDANet::Tensor(out_shape, CUDANet::DType::FLOAT32, backend);
|
||||
output = CUDANet::Tensor(out_shape, dtype, backend);
|
||||
}
|
||||
|
||||
Add::~Add() {}
|
||||
|
||||
Reference in New Issue
Block a user