mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-12-22 14:24:22 +00:00
WIP Migrate Dense layer
This commit is contained in:
@@ -22,12 +22,16 @@ Dense::Dense(CUDANet::Backend *backend, CUDANet::Shape input_shape, CUDANet::Sha
|
||||
|
||||
auto weights = CUDANet::Tensor{Shape(input_len * output_len), CUDANet::DType::FLOAT32, backend};
|
||||
auto biases = CUDANet::Tensor(Shape(output_len), CUDANet::DType::FLOAT32, backend);
|
||||
auto output = CUDANet::Tensor(Shape(output_len), CUDANet::DType::FLOAT32, backend);
|
||||
|
||||
weights.zero();
|
||||
biases.zero();
|
||||
}
|
||||
|
||||
CUDANet::Tensor& Dense::forward(CUDANet::Tensor &input);
|
||||
CUDANet::Tensor& Dense::forward(CUDANet::Tensor &input) {
|
||||
backend->dense(weights, biases, input, output, in_shape[0], out_shape[0]);
|
||||
return output;
|
||||
}
|
||||
|
||||
CUDANet::Shape Dense::input_shape() {
|
||||
return in_shape;
|
||||
@@ -45,13 +49,17 @@ size_t Dense::output_size() {
|
||||
return out_shape[0];
|
||||
};
|
||||
|
||||
void Dense::set_weights(CUDANet::Tensor &input);
|
||||
void Dense::set_weights(void *input) {
|
||||
weights.set_data<float>(static_cast<float*>(input));
|
||||
}
|
||||
|
||||
CUDANet::Tensor& Dense::get_weights() {
|
||||
return weights;
|
||||
}
|
||||
|
||||
void Dense::set_biases(CUDANet::Tensor &input);
|
||||
void Dense::set_biases(void *input) {
|
||||
biases.set_data<float>(static_cast<float*>(input));
|
||||
}
|
||||
|
||||
CUDANet::Tensor& Dense::get_biases() {
|
||||
return biases;
|
||||
|
||||
Reference in New Issue
Block a user