Refactor Layer interface to return size of weights and biases instead of Tensor references

This commit is contained in:
2025-11-23 20:44:25 +01:00
parent 547cd0c224
commit 9f1a56c699
14 changed files with 48 additions and 36 deletions

View File

@@ -32,11 +32,11 @@ class Layer {
virtual void set_weights(void *input) = 0;
virtual CUDANet::Tensor& get_weights() = 0;
virtual size_t get_weights_size() = 0;
virtual void set_biases(void *input) = 0;
virtual CUDANet::Tensor& get_biases() = 0;
virtual size_t get_biases_size() = 0;
};
} // namespace CUDANet::Layers

View File

@@ -41,11 +41,11 @@ class Activation : public Layer {
void set_weights(void *input) override;
CUDANet::Tensor& get_weights() override;
size_t get_weights_size() override;
void set_biases(void *input) override;
CUDANet::Tensor& get_biases() override;
size_t get_biases_size() override;
private:

View File

@@ -28,11 +28,11 @@ class AvgPool2d : public Layer {
void set_weights(void* input) override;
CUDANet::Tensor& get_weights() override;
size_t get_weights_size() override;
void set_biases(void* input) override;
CUDANet::Tensor& get_biases() override;
size_t get_biases_size() override;
protected:
CUDANet::Shape in_shape;

View File

@@ -22,11 +22,11 @@ class BatchNorm2d : public Layer {
void set_weights(void* input) override;
CUDANet::Tensor& get_weights() override;
size_t get_weights_size() override;
void set_biases(void* input) override;
CUDANet::Tensor& get_biases() override;
size_t get_biases_size() override;
void set_running_mean(void* input);

View File

@@ -32,11 +32,11 @@ class Conv2d : public Layer {
void set_weights(void* input) override;
CUDANet::Tensor& get_weights() override;
size_t get_weights_size() override;
void set_biases(void* input) override;
CUDANet::Tensor& get_biases() override;
size_t get_biases_size() override;
CUDANet::Shape get_padding_shape();

View File

@@ -28,11 +28,11 @@ class Dense : public Layer {
void set_weights(void *input) override;
CUDANet::Tensor& get_weights() override;
size_t get_weights_size() override;
void set_biases(void *input) override;
CUDANet::Tensor& get_biases() override;
size_t get_biases_size() override;
private:
CUDANet::Backend *backend;

View File

@@ -27,11 +27,11 @@ class MaxPool2d : public Layer {
void set_weights(void *input) override;
CUDANet::Tensor& get_weights() override;
size_t get_weights_size() override;
void set_biases(void *input) override;
CUDANet::Tensor& get_biases() override;
size_t get_biases_size() override;