mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-12-23 14:54:28 +00:00
Refactor Layer interface to return size of weights and biases instead of Tensor references
This commit is contained in:
@@ -32,11 +32,11 @@ class Layer {
|
||||
|
||||
virtual void set_weights(void *input) = 0;
|
||||
|
||||
virtual CUDANet::Tensor& get_weights() = 0;
|
||||
virtual size_t get_weights_size() = 0;
|
||||
|
||||
virtual void set_biases(void *input) = 0;
|
||||
|
||||
virtual CUDANet::Tensor& get_biases() = 0;
|
||||
virtual size_t get_biases_size() = 0;
|
||||
};
|
||||
|
||||
} // namespace CUDANet::Layers
|
||||
|
||||
@@ -41,11 +41,11 @@ class Activation : public Layer {
|
||||
|
||||
void set_weights(void *input) override;
|
||||
|
||||
CUDANet::Tensor& get_weights() override;
|
||||
size_t get_weights_size() override;
|
||||
|
||||
void set_biases(void *input) override;
|
||||
|
||||
CUDANet::Tensor& get_biases() override;
|
||||
size_t get_biases_size() override;
|
||||
|
||||
|
||||
private:
|
||||
|
||||
@@ -28,11 +28,11 @@ class AvgPool2d : public Layer {
|
||||
|
||||
void set_weights(void* input) override;
|
||||
|
||||
CUDANet::Tensor& get_weights() override;
|
||||
size_t get_weights_size() override;
|
||||
|
||||
void set_biases(void* input) override;
|
||||
|
||||
CUDANet::Tensor& get_biases() override;
|
||||
size_t get_biases_size() override;
|
||||
|
||||
protected:
|
||||
CUDANet::Shape in_shape;
|
||||
|
||||
@@ -22,11 +22,11 @@ class BatchNorm2d : public Layer {
|
||||
|
||||
void set_weights(void* input) override;
|
||||
|
||||
CUDANet::Tensor& get_weights() override;
|
||||
size_t get_weights_size() override;
|
||||
|
||||
void set_biases(void* input) override;
|
||||
|
||||
CUDANet::Tensor& get_biases() override;
|
||||
size_t get_biases_size() override;
|
||||
|
||||
void set_running_mean(void* input);
|
||||
|
||||
|
||||
@@ -32,11 +32,11 @@ class Conv2d : public Layer {
|
||||
|
||||
void set_weights(void* input) override;
|
||||
|
||||
CUDANet::Tensor& get_weights() override;
|
||||
size_t get_weights_size() override;
|
||||
|
||||
void set_biases(void* input) override;
|
||||
|
||||
CUDANet::Tensor& get_biases() override;
|
||||
size_t get_biases_size() override;
|
||||
|
||||
CUDANet::Shape get_padding_shape();
|
||||
|
||||
|
||||
@@ -28,11 +28,11 @@ class Dense : public Layer {
|
||||
|
||||
void set_weights(void *input) override;
|
||||
|
||||
CUDANet::Tensor& get_weights() override;
|
||||
size_t get_weights_size() override;
|
||||
|
||||
void set_biases(void *input) override;
|
||||
|
||||
CUDANet::Tensor& get_biases() override;
|
||||
size_t get_biases_size() override;
|
||||
|
||||
private:
|
||||
CUDANet::Backend *backend;
|
||||
|
||||
@@ -27,11 +27,11 @@ class MaxPool2d : public Layer {
|
||||
|
||||
void set_weights(void *input) override;
|
||||
|
||||
CUDANet::Tensor& get_weights() override;
|
||||
size_t get_weights_size() override;
|
||||
|
||||
void set_biases(void *input) override;
|
||||
|
||||
CUDANet::Tensor& get_biases() override;
|
||||
size_t get_biases_size() override;
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user