Refactor Layer interface to return size of weights and biases instead of Tensor references

This commit is contained in:
2025-11-23 20:44:25 +01:00
parent 547cd0c224
commit 9f1a56c699
14 changed files with 48 additions and 36 deletions

View File

@@ -32,11 +32,11 @@ class Layer {
// Loads weight values into the layer from an untyped, caller-provided buffer.
// NOTE(review): the Conv2d/Dense/BatchNorm2d definitions in this diff cast the
// pointer to float*; the expected element count is not visible here — confirm.
virtual void set_weights(void *input) = 0;
// Accessor returning the weights tensor by reference.
// NOTE(review): the commit title says Tensor-reference getters are replaced by
// size getters; this rendering has no +/- markers, so presumably this is the
// removed side of the hunk — confirm against the real file.
virtual CUDANet::Tensor& get_weights() = 0;
// Size of the layer's weights. Model::load_weights (shown in this diff)
// compares it against values.size() before calling set_weights().
virtual size_t get_weights_size() = 0;
// Loads bias values into the layer from an untyped, caller-provided buffer.
virtual void set_biases(void *input) = 0;
// Accessor returning the biases tensor by reference.
// NOTE(review): presumably the removed side of the hunk, as for get_weights().
virtual CUDANet::Tensor& get_biases() = 0;
// Size of the layer's biases. Model::load_weights (shown in this diff)
// compares it against values.size() before calling set_biases().
virtual size_t get_biases_size() = 0;
};
} // namespace CUDANet::Layers

View File

@@ -41,11 +41,11 @@ class Activation : public Layer {
// Layer override; the definition shown in this diff has an empty body.
void set_weights(void *input) override;
// Layer override returning the weights tensor by reference.
// NOTE(review): presumably the removed side of this hunk (no +/- markers) — confirm.
CUDANet::Tensor& get_weights() override;
// Layer override; the definition shown in this diff returns 0.
size_t get_weights_size() override;
// Layer override; the definition shown in this diff has an empty body.
void set_biases(void *input) override;
// Layer override returning the biases tensor by reference.
// NOTE(review): presumably the removed side of this hunk — confirm.
CUDANet::Tensor& get_biases() override;
// Layer override; the definition shown in this diff returns 0.
size_t get_biases_size() override;
private:

View File

@@ -28,11 +28,11 @@ class AvgPool2d : public Layer {
// Layer override; the definition shown in this diff has an empty body.
void set_weights(void* input) override;
// Layer override returning the weights tensor by reference.
// NOTE(review): presumably the removed side of this hunk (no +/- markers) — confirm.
CUDANet::Tensor& get_weights() override;
// Layer override; the definition shown in this diff returns 0.
size_t get_weights_size() override;
// Layer override; the definition shown in this diff has an empty body.
void set_biases(void* input) override;
// Layer override returning the biases tensor by reference.
// NOTE(review): presumably the removed side of this hunk — confirm.
CUDANet::Tensor& get_biases() override;
// Layer override; the definition shown in this diff returns 0.
size_t get_biases_size() override;
protected:
CUDANet::Shape in_shape;

View File

@@ -22,11 +22,11 @@ class BatchNorm2d : public Layer {
// Layer override; the definition shown in this diff casts input to float*
// and forwards it to weights.set_data<float>().
void set_weights(void* input) override;
// Layer override returning the weights tensor by reference.
// NOTE(review): presumably the removed side of this hunk (no +/- markers) — confirm.
CUDANet::Tensor& get_weights() override;
// Layer override; the definition shown in this diff returns weights.size().
size_t get_weights_size() override;
// Layer override; the definition shown in this diff casts input to float*
// and forwards it to biases.set_data<float>().
void set_biases(void* input) override;
// Layer override returning the biases tensor by reference.
// NOTE(review): presumably the removed side of this hunk — confirm.
CUDANet::Tensor& get_biases() override;
// Layer override; the definition shown in this diff returns biases.size().
size_t get_biases_size() override;
void set_running_mean(void* input);

View File

@@ -32,11 +32,11 @@ class Conv2d : public Layer {
// Layer override; the definition shown in this diff casts input to float*
// and forwards it to weights.set_data<float>().
void set_weights(void* input) override;
// Layer override returning the weights tensor by reference.
// NOTE(review): presumably the removed side of this hunk (no +/- markers) — confirm.
CUDANet::Tensor& get_weights() override;
// Layer override; the definition shown in this diff returns weights.size().
size_t get_weights_size() override;
// Layer override; the definition shown in this diff casts input to float*
// and forwards it to biases.set_data<float>().
void set_biases(void* input) override;
// Layer override returning the biases tensor by reference.
// NOTE(review): presumably the removed side of this hunk — confirm.
CUDANet::Tensor& get_biases() override;
// Layer override; the definition shown in this diff returns biases.size().
size_t get_biases_size() override;
CUDANet::Shape get_padding_shape();

View File

@@ -28,11 +28,11 @@ class Dense : public Layer {
// Layer override; the definition shown in this diff casts input to float*
// and forwards it to weights.set_data<float>().
void set_weights(void *input) override;
// Layer override returning the weights tensor by reference.
// NOTE(review): presumably the removed side of this hunk (no +/- markers) — confirm.
CUDANet::Tensor& get_weights() override;
// Layer override; the definition shown in this diff returns weights.size().
size_t get_weights_size() override;
// Layer override; the definition shown in this diff casts input to float*
// and forwards it to biases.set_data<float>().
void set_biases(void *input) override;
// Layer override returning the biases tensor by reference.
// NOTE(review): presumably the removed side of this hunk — confirm.
CUDANet::Tensor& get_biases() override;
// Layer override; the definition shown in this diff returns biases.size().
size_t get_biases_size() override;
private:
CUDANet::Backend *backend;

View File

@@ -27,11 +27,11 @@ class MaxPool2d : public Layer {
// Layer override; the definition shown in this diff has an empty body.
void set_weights(void *input) override;
// Layer override returning the weights tensor by reference.
// NOTE(review): presumably the removed side of this hunk (no +/- markers) — confirm.
CUDANet::Tensor& get_weights() override;
// Layer override; the definition shown in this diff returns 0.
size_t get_weights_size() override;
// Layer override; the definition shown in this diff has an empty body.
void set_biases(void *input) override;
// Layer override returning the biases tensor by reference.
// NOTE(review): presumably the removed side of this hunk — confirm.
CUDANet::Tensor& get_biases() override;
// Layer override; the definition shown in this diff returns 0.
size_t get_biases_size() override;

View File

@@ -59,8 +59,12 @@ size_t Activation::output_size() {
// No-op: the body is empty, so `input` is ignored.
void Activation::set_weights(void *input) {}
// NOTE(review): declared to return Tensor& but the body is empty, so control
// flows off the end without a return value — undefined behaviour if called.
// Presumably this is the removed side of the hunk (replaced by
// get_weights_size()); confirm it is gone from the real file.
CUDANet::Tensor& Activation::get_weights() {}
// Always reports a weights size of 0.
size_t Activation::get_weights_size() {
return 0;
}
// No-op: the body is empty, so `input` is ignored.
void Activation::set_biases(void *input) {}
// NOTE(review): declared to return Tensor& but the body is empty, so control
// flows off the end without a return value — undefined behaviour if called.
// Presumably this is the removed side of the hunk (replaced by
// get_biases_size()); confirm it is gone from the real file.
CUDANet::Tensor& Activation::get_biases() {}
// Always reports a biases size of 0.
size_t Activation::get_biases_size() {
return 0;
}

View File

@@ -81,11 +81,15 @@ size_t AvgPool2d::output_size() {
// No-op: the body is empty, so `input` is ignored.
void AvgPool2d::set_weights(void* input) {}
// NOTE(review): declared to return Tensor& but the body is empty, so control
// flows off the end without a return value — undefined behaviour if called.
// Presumably this is the removed side of the hunk (replaced by
// get_weights_size()); confirm it is gone from the real file.
CUDANet::Tensor& AvgPool2d::get_weights() {}
// Always reports a weights size of 0.
size_t AvgPool2d::get_weights_size() {
return 0;
}
// No-op: the body is empty, so `input` is ignored.
void AvgPool2d::set_biases(void* input) {}
// NOTE(review): declared to return Tensor& but the body is empty, so control
// flows off the end without a return value — undefined behaviour if called.
// Presumably this is the removed side of the hunk (replaced by
// get_biases_size()); confirm it is gone from the real file.
CUDANet::Tensor& AvgPool2d::get_biases() {}
// Always reports a biases size of 0.
size_t AvgPool2d::get_biases_size() {
return 0;
}
AdaptiveAvgPool2d::AdaptiveAvgPool2d(

View File

@@ -74,16 +74,16 @@ void BatchNorm2d::set_weights(void* input) {
weights.set_data<float>(static_cast<float*>(input));
}
CUDANet::Tensor& BatchNorm2d::get_weights() {
return weights;
// Reports whatever the `weights` member's size() returns.
// NOTE(review): whether size() counts elements or bytes is not visible here;
// Model::load_weights compares it with values.size() — confirm the unit.
size_t BatchNorm2d::get_weights_size() {
return weights.size();
}
// Reinterprets `input` as float* and hands it to biases.set_data<float>().
// NOTE(review): how many floats set_data reads is not visible here; callers
// presumably must supply get_biases_size() values — confirm.
void BatchNorm2d::set_biases(void* input) {
biases.set_data<float>(static_cast<float*>(input));
}
CUDANet::Tensor& BatchNorm2d::get_biases() {
return biases;
// Reports whatever the `biases` member's size() returns.
// NOTE(review): whether size() counts elements or bytes is not visible here;
// Model::load_weights compares it with values.size() — confirm the unit.
size_t BatchNorm2d::get_biases_size() {
return biases.size();
}
void BatchNorm2d::set_running_mean(void* input) {

View File

@@ -96,16 +96,16 @@ void Conv2d::set_weights(void* input) {
weights.set_data<float>(static_cast<float*>(input));
}
CUDANet::Tensor& Conv2d::get_weights() {
return weights;
// Reports whatever the `weights` member's size() returns.
// NOTE(review): whether size() counts elements or bytes is not visible here;
// Model::load_weights compares it with values.size() — confirm the unit.
size_t Conv2d::get_weights_size() {
return weights.size();
}
// Reinterprets `input` as float* and hands it to biases.set_data<float>().
// NOTE(review): how many floats set_data reads is not visible here; callers
// presumably must supply get_biases_size() values — confirm.
void Conv2d::set_biases(void* input) {
biases.set_data<float>(static_cast<float*>(input));
}
CUDANet::Tensor& Conv2d::get_biases() {
return biases;
// Reports whatever the `biases` member's size() returns.
// NOTE(review): whether size() counts elements or bytes is not visible here;
// Model::load_weights compares it with values.size() — confirm the unit.
size_t Conv2d::get_biases_size() {
return biases.size();
}
CUDANet::Shape Conv2d::get_padding_shape() {

View File

@@ -55,14 +55,14 @@ void Dense::set_weights(void* input) {
weights.set_data<float>(static_cast<float*>(input));
}
CUDANet::Tensor& Dense::get_weights() {
return weights;
// Reports whatever the `weights` member's size() returns.
// NOTE(review): whether size() counts elements or bytes is not visible here;
// Model::load_weights compares it with values.size() — confirm the unit.
size_t Dense::get_weights_size() {
return weights.size();
}
// Reinterprets `input` as float* and hands it to biases.set_data<float>().
// NOTE(review): how many floats set_data reads is not visible here; callers
// presumably must supply get_biases_size() values — confirm.
void Dense::set_biases(void* input) {
biases.set_data<float>(static_cast<float*>(input));
}
CUDANet::Tensor& Dense::get_biases() {
return biases;
// Reports whatever the `biases` member's size() returns.
// NOTE(review): whether size() counts elements or bytes is not visible here;
// Model::load_weights compares it with values.size() — confirm the unit.
size_t Dense::get_biases_size() {
return biases.size();
}

View File

@@ -75,8 +75,12 @@ size_t MaxPool2d::output_size() {
// No-op: the body is empty, so `input` is ignored.
void MaxPool2d::set_weights(void* input) {}
// NOTE(review): declared to return Tensor& but the body is empty, so control
// flows off the end without a return value — undefined behaviour if called.
// Presumably this is the removed side of the hunk (replaced by
// get_weights_size()); confirm it is gone from the real file.
CUDANet::Tensor& MaxPool2d::get_weights() {}
// Always reports a weights size of 0.
size_t MaxPool2d::get_weights_size() {
return 0;
}
// No-op: the body is empty, so `input` is ignored.
void MaxPool2d::set_biases(void* input) {}
// NOTE(review): declared to return Tensor& but the body is empty, so control
// flows off the end without a return value — undefined behaviour if called.
// Presumably this is the removed side of the hunk (replaced by
// get_biases_size()); confirm it is gone from the real file.
CUDANet::Tensor& MaxPool2d::get_biases() {}
// Always reports a biases size of 0.
size_t MaxPool2d::get_biases_size() {
return 0;
}

View File

@@ -128,20 +128,20 @@ void Model::load_weights(const std::string& path) {
Layer* layer = layer_map[tensor_info.name];
if (tensor_info.type == TensorType::WEIGHT) {
if (layer->get_weights().size() != values.size()) {
if (layer->get_weights_size() != values.size()) {
std::cerr << "Layer: " << tensor_info.name
<< " has incorrect number of weights, expected "
<< layer->get_weights().size() << " but got "
<< layer->get_weights_size() << " but got "
<< values.size() << ", skipping" << std::endl;
continue;
}
layer->set_weights(values.data());
} else if (tensor_info.type == TensorType::BIAS) {
if (layer->get_biases().size() != values.size()) {
if (layer->get_biases_size() != values.size()) {
std::cerr << "Layer: " << tensor_info.name
<< " has incorrect number of biases, expected "
<< layer->get_biases().size() << " but got "
<< layer->get_biases_size() << " but got "
<< values.size() << ", skipping" << std::endl;
continue;
}