mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-12-22 14:24:22 +00:00
Fix some dense layer issues
This commit is contained in:
@@ -1,34 +1,41 @@
|
||||
#include "dense.hpp"
|
||||
|
||||
#include <format>
|
||||
#include <stdexcept>
|
||||
|
||||
#include "dense.hpp"
|
||||
|
||||
using namespace CUDANet::Layers;
|
||||
|
||||
Dense::Dense(CUDANet::Backend *backend, CUDANet::Shape input_shape, CUDANet::Shape output_shape)
|
||||
: backend(backend), in_shape(input_shape), out_shape(output_shape) {
|
||||
Dense::Dense(CUDANet::Backend* backend, CUDANet::Shape in, CUDANet::Shape out)
|
||||
: backend(backend),
|
||||
in_shape(in),
|
||||
out_shape(out),
|
||||
weights(
|
||||
CUDANet::Tensor{{in[0] * out[0]}, CUDANet::DType::FLOAT32, backend}
|
||||
),
|
||||
biases(CUDANet::Tensor({out[0]}, CUDANet::DType::FLOAT32, backend)),
|
||||
output(CUDANet::Tensor({out[0]}, CUDANet::DType::FLOAT32, backend)) {
|
||||
// Allocate memory for weights and biases
|
||||
|
||||
if (input_shape.size() != 1) {
|
||||
throw std::runtime_error(std::format("Invalid shape. Expected [1], got {}", input_shape));
|
||||
}
|
||||
|
||||
if (output_shape.size() != 1) {
|
||||
throw std::runtime_error(std::format("Invalid shape. Expected [1], got {}", output_shape));
|
||||
if (in.size() != 1) {
|
||||
throw std::runtime_error(
|
||||
std::format("Invalid shape. Expected [1], got {}", in)
|
||||
);
|
||||
}
|
||||
|
||||
auto input_len = input_shape[0];
|
||||
auto output_len = output_shape[0];
|
||||
if (out.size() != 1) {
|
||||
throw std::runtime_error(
|
||||
std::format("Invalid shape. Expected [1], got {}", out)
|
||||
);
|
||||
}
|
||||
|
||||
auto weights = CUDANet::Tensor{Shape(input_len * output_len), CUDANet::DType::FLOAT32, backend};
|
||||
auto biases = CUDANet::Tensor(Shape(output_len), CUDANet::DType::FLOAT32, backend);
|
||||
auto output = CUDANet::Tensor(Shape(output_len), CUDANet::DType::FLOAT32, backend);
|
||||
auto input_len = in[0];
|
||||
auto output_len = out[0];
|
||||
|
||||
weights.zero();
|
||||
biases.zero();
|
||||
}
|
||||
|
||||
CUDANet::Tensor& Dense::forward(CUDANet::Tensor &input) {
|
||||
CUDANet::Tensor& Dense::forward(const CUDANet::Tensor& input) {
|
||||
backend->dense(weights, biases, input, output, in_shape[0], out_shape[0]);
|
||||
return output;
|
||||
}
|
||||
@@ -49,7 +56,7 @@ size_t Dense::output_size() {
|
||||
return out_shape[0];
|
||||
};
|
||||
|
||||
void Dense::set_weights(void *input) {
|
||||
void Dense::set_weights(void* input) {
|
||||
weights.set_data<float>(static_cast<float*>(input));
|
||||
}
|
||||
|
||||
@@ -57,7 +64,7 @@ CUDANet::Tensor& Dense::get_weights() {
|
||||
return weights;
|
||||
}
|
||||
|
||||
void Dense::set_biases(void *input) {
|
||||
void Dense::set_biases(void* input) {
|
||||
biases.set_data<float>(static_cast<float*>(input));
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user