Fix some dense layer issues

This commit is contained in:
2025-11-18 22:17:08 +01:00
parent 7f203b8947
commit 4c26efe826
8 changed files with 110 additions and 44 deletions

View File

@@ -32,12 +32,12 @@ class Backend {
) = 0;
virtual CUDANet::Tensor& dense(
CUDANet::Tensor& weights,
CUDANet::Tensor& biases,
CUDANet::Tensor& input,
const CUDANet::Tensor& weights,
const CUDANet::Tensor& biases,
const CUDANet::Tensor& input,
CUDANet::Tensor& output,
size_t input_size,
size_t output_size
const size_t input_size,
const size_t output_size
) = 0;
};

View File

@@ -29,12 +29,12 @@ class CUDA : public Backend {
) override;
CUDANet::Tensor& dense(
CUDANet::Tensor& weights,
CUDANet::Tensor& biases,
CUDANet::Tensor& input,
const CUDANet::Tensor& weights,
const CUDANet::Tensor& biases,
const CUDANet::Tensor& input,
CUDANet::Tensor& output,
size_t input_size,
size_t output_size
const size_t input_size,
const size_t output_size
) override;
};

View File

@@ -20,7 +20,7 @@ class Layer {
virtual ~Layer(){};
virtual CUDANet::Tensor& forward(CUDANet::Tensor &input) = 0;
virtual CUDANet::Tensor& forward(const CUDANet::Tensor &input) = 0;
virtual CUDANet::Shape input_shape() = 0;

View File

@@ -18,7 +18,7 @@ class Dense : public Layer {
~Dense();
CUDANet::Tensor& forward(CUDANet::Tensor &input) override;
CUDANet::Tensor& forward(const CUDANet::Tensor &input) override;
CUDANet::Shape input_shape() override;

View File

@@ -22,6 +22,12 @@ public:
Tensor() = default;
Tensor(Shape shape, DType dtype, CUDANet::Backend* backend);
Tensor(Tensor&& other) noexcept;
Tensor& operator=(Tensor&& other) noexcept;
Tensor(const Tensor&) = delete;
Tensor& operator=(const Tensor&) = delete;
~Tensor();
size_t size() const;