WIP Migrate Dense layer

This commit is contained in:
2025-11-18 21:12:47 +01:00
parent 64eac7050b
commit 7f203b8947
14 changed files with 116 additions and 221 deletions

View File

@@ -22,12 +22,16 @@ Dense::Dense(CUDANet::Backend *backend, CUDANet::Shape input_shape, CUDANet::Sha
auto weights = CUDANet::Tensor{Shape(input_len * output_len), CUDANet::DType::FLOAT32, backend};
auto biases = CUDANet::Tensor(Shape(output_len), CUDANet::DType::FLOAT32, backend);
auto output = CUDANet::Tensor(Shape(output_len), CUDANet::DType::FLOAT32, backend);
weights.zero();
biases.zero();
}
CUDANet::Tensor& Dense::forward(CUDANet::Tensor &input);
CUDANet::Tensor& Dense::forward(CUDANet::Tensor &input) {
backend->dense(weights, biases, input, output, in_shape[0], out_shape[0]);
return output;
}
CUDANet::Shape Dense::input_shape() {
return in_shape;
@@ -45,13 +49,17 @@ size_t Dense::output_size() {
return out_shape[0];
};
void Dense::set_weights(CUDANet::Tensor &input);
void Dense::set_weights(void *input) {
weights.set_data<float>(static_cast<float*>(input));
}
CUDANet::Tensor& Dense::get_weights() {
return weights;
}
void Dense::set_biases(CUDANet::Tensor &input);
void Dense::set_biases(void *input) {
biases.set_data<float>(static_cast<float*>(input));
}
CUDANet::Tensor& Dense::get_biases() {
return biases;