mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-12-22 14:24:22 +00:00
Fix Tensor issues
This commit is contained in:
@@ -10,10 +10,10 @@ Dense::Dense(CUDANet::Backend* backend, CUDANet::Shape in, CUDANet::Shape out)
|
||||
in_shape(in),
|
||||
out_shape(out),
|
||||
weights(
|
||||
CUDANet::Tensor{{in[0] * out[0]}, CUDANet::DType::FLOAT32, backend}
|
||||
CUDANet::Tensor(Shape{in[0] * out[0]}, CUDANet::DType::FLOAT32, backend)
|
||||
),
|
||||
biases(CUDANet::Tensor({out[0]}, CUDANet::DType::FLOAT32, backend)),
|
||||
output(CUDANet::Tensor({out[0]}, CUDANet::DType::FLOAT32, backend)) {
|
||||
biases(CUDANet::Tensor(Shape{out[0]}, CUDANet::DType::FLOAT32, backend)),
|
||||
output(CUDANet::Tensor(Shape{out[0]}, CUDANet::DType::FLOAT32, backend)) {
|
||||
// Allocate memory for weights and biases
|
||||
|
||||
if (in.size() != 1) {
|
||||
@@ -35,6 +35,8 @@ Dense::Dense(CUDANet::Backend* backend, CUDANet::Shape in, CUDANet::Shape out)
|
||||
biases.zero();
|
||||
}
|
||||
|
||||
Dense::~Dense() {}
|
||||
|
||||
CUDANet::Tensor& Dense::forward(const CUDANet::Tensor& input) {
|
||||
backend->dense(weights, biases, input, output, in_shape[0], out_shape[0]);
|
||||
return output;
|
||||
|
||||
Reference in New Issue
Block a user