mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-12-22 14:24:22 +00:00
33 lines
811 B
C++
33 lines
811 B
C++
#include "layers/add.hpp"
|
|
|
|
using namespace CUDANet::Layers;
|
|
|
|
|
|
/**
 * @brief Construct an Add layer that uses the backend's default data type.
 *
 * Delegates to the dtype-taking constructor, supplying
 * backend->get_default_dtype() as the data type.
 *
 * @param a_shape Shape of the first operand.
 * @param b_shape Shape of the second operand.
 * @param backend Backend queried for the default dtype. It is dereferenced
 *                here, so it must not be null.
 */
Add::Add(
    CUDANet::Shape    a_shape,
    CUDANet::Shape    b_shape,
    CUDANet::Backend* backend
)
    : Add(a_shape, b_shape, backend->get_default_dtype(), backend) {}
|
|
|
|
/**
 * @brief Construct an Add layer with an explicit data type.
 *
 * Stores the backend and dtype, verifies that both operand shapes compare
 * equal, and then creates the output tensor with that shared shape.
 *
 * @param a_shape Shape of the first operand.
 * @param b_shape Shape of the second operand; must equal a_shape.
 * @param dtype   Data type used for the output tensor.
 * @param backend Backend stored by the layer and passed to the output tensor.
 *
 * @throws InvalidShapeException if a_shape != b_shape.
 */
Add::Add(
    CUDANet::Shape    a_shape,
    CUDANet::Shape    b_shape,
    CUDANet::DType    dtype,
    CUDANet::Backend* backend
)
    : backend(backend), dtype(dtype) {
    const bool shapes_differ = (a_shape != b_shape);
    if (shapes_differ) {
        throw InvalidShapeException(
            "Add requires matching dimensions", a_shape, b_shape
        );
    }

    // Both operands have the same shape at this point, so the first one
    // is used as the shape of the result.
    out_shape = a_shape;
    output    = CUDANet::Tensor(out_shape, dtype, backend);
}
|
|
|
|
/**
 * @brief Destroy the Add layer.
 *
 * No explicit cleanup is performed here; data members are destroyed by
 * their own destructors.
 */
Add::~Add() = default;
|
|
|
|
/**
 * @brief Run the layer on two input tensors.
 *
 * Zeroes the layer's output tensor, then delegates to
 * backend->add(input_a, input_b, output).
 *
 * @param input_a First operand.
 * @param input_b Second operand.
 * @return Reference to the layer's own output tensor. The same tensor is
 *         reused on every call, so a later forward() overwrites its contents.
 */
CUDANet::Tensor& Add::forward(
    CUDANet::Tensor& input_a,
    CUDANet::Tensor& input_b
) {
    // Clear the result tensor before handing it to the backend.
    output.zero();
    backend->add(input_a, input_b, output);
    return output;
}
|