mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-12-24 07:14:22 +00:00
Fix some dense layer issues
This commit is contained in:
@@ -32,12 +32,12 @@ class Backend {
|
||||
) = 0;
|
||||
|
||||
virtual CUDANet::Tensor& dense(
|
||||
CUDANet::Tensor& weights,
|
||||
CUDANet::Tensor& biases,
|
||||
CUDANet::Tensor& input,
|
||||
const CUDANet::Tensor& weights,
|
||||
const CUDANet::Tensor& biases,
|
||||
const CUDANet::Tensor& input,
|
||||
CUDANet::Tensor& output,
|
||||
size_t input_size,
|
||||
size_t output_size
|
||||
const size_t input_size,
|
||||
const size_t output_size
|
||||
) = 0;
|
||||
};
|
||||
|
||||
|
||||
@@ -29,12 +29,12 @@ class CUDA : public Backend {
|
||||
) override;
|
||||
|
||||
CUDANet::Tensor& dense(
|
||||
CUDANet::Tensor& weights,
|
||||
CUDANet::Tensor& biases,
|
||||
CUDANet::Tensor& input,
|
||||
const CUDANet::Tensor& weights,
|
||||
const CUDANet::Tensor& biases,
|
||||
const CUDANet::Tensor& input,
|
||||
CUDANet::Tensor& output,
|
||||
size_t input_size,
|
||||
size_t output_size
|
||||
const size_t input_size,
|
||||
const size_t output_size
|
||||
) override;
|
||||
};
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ class Layer {
|
||||
|
||||
virtual ~Layer(){};
|
||||
|
||||
virtual CUDANet::Tensor& forward(CUDANet::Tensor &input) = 0;
|
||||
virtual CUDANet::Tensor& forward(const CUDANet::Tensor &input) = 0;
|
||||
|
||||
virtual CUDANet::Shape input_shape() = 0;
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ class Dense : public Layer {
|
||||
|
||||
~Dense();
|
||||
|
||||
CUDANet::Tensor& forward(CUDANet::Tensor &input) override;
|
||||
CUDANet::Tensor& forward(const CUDANet::Tensor &input) override;
|
||||
|
||||
CUDANet::Shape input_shape() override;
|
||||
|
||||
|
||||
@@ -22,6 +22,12 @@ public:
|
||||
|
||||
Tensor() = default;
|
||||
Tensor(Shape shape, DType dtype, CUDANet::Backend* backend);
|
||||
|
||||
Tensor(Tensor&& other) noexcept;
|
||||
Tensor& operator=(Tensor&& other) noexcept;
|
||||
Tensor(const Tensor&) = delete;
|
||||
Tensor& operator=(const Tensor&) = delete;
|
||||
|
||||
~Tensor();
|
||||
|
||||
size_t size() const;
|
||||
|
||||
Reference in New Issue
Block a user