Fix Tensor issues

2025-11-18 22:38:56 +01:00
parent 4c26efe826
commit 10c84d75fc
4 changed files with 21 additions and 25 deletions

View File

@@ -2,10 +2,11 @@
 #include <cstddef>
-#include "tensor.hpp"
 
 namespace CUDANet {
 
+// Forward declaration
+class Tensor;
 
 class Backend {
    public:
     // Memory management

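Review note: the backend header now forward-declares Tensor instead of including tensor.hpp, which is the usual way to break a circular dependency between the backend and tensor headers. A forward declaration is sufficient because Backend only refers to Tensor by reference or pointer. A minimal sketch of the pattern, assuming the zero/copy_to_device methods seen in the diff (whether they are virtual is an assumption):

// backend.hpp sketch (only the forward-declaration pattern is taken from the commit)
#include <cstddef>

namespace CUDANet {

class Tensor;  // incomplete type: enough for references and pointers

class Backend {
   public:
    virtual ~Backend() = default;
    // Tensor is passed by reference, so its full definition is not needed in this header
    virtual void zero(Tensor& t) = 0;
    virtual void copy_to_device(Tensor& t, void* src, std::size_t bytes) = 0;
};

}  // namespace CUDANet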
View File

@@ -34,15 +34,21 @@ public:
     size_t numel() const;
 
     template <typename T>
-    const T* data() const;
+    const T* data() const {
+        return static_cast<T*>(d_ptr);
+    }
 
     template <typename T>
-    T* data();
+    T* data() {
+        return static_cast<T*>(d_ptr);
+    }
 
     void zero();
 
     template <typename T>
-    void set_data(T *data);
+    void set_data(T *data) {
+        backend->copy_to_device(*this, data, total_size);
+    }
 
    private:
     Shape shape;

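Review note: moving the bodies of the templated accessors into the class definition is what fixes the linker side of the "Tensor issues": a member template defined only in tensor.cu is never instantiated for the T a caller requests from another translation unit, so calls such as data<float>() come back as undefined symbols. With the definitions in the header, each translation unit instantiates what it needs. A hedged usage sketch; the constructor signature and DType value are taken from the diff below, the `backend` pointer and the host data are assumptions:

// usage sketch, assuming a valid CUDANet::Backend* named `backend`
#include <vector>
CUDANet::Tensor t(CUDANet::Shape{4}, CUDANet::DType::FLOAT32, backend);
std::vector<float> host = {1.f, 2.f, 3.f, 4.f};
t.set_data(host.data());             // instantiates set_data<float> from the header definition
const float* dev = t.data<float>();  // raw device pointer; not dereferenceable on the host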
View File

@@ -10,10 +10,10 @@ Dense::Dense(CUDANet::Backend* backend, CUDANet::Shape in, CUDANet::Shape out)
       in_shape(in),
       out_shape(out),
       weights(
-          CUDANet::Tensor{{in[0] * out[0]}, CUDANet::DType::FLOAT32, backend}
+          CUDANet::Tensor(Shape{in[0] * out[0]}, CUDANet::DType::FLOAT32, backend)
       ),
-      biases(CUDANet::Tensor({out[0]}, CUDANet::DType::FLOAT32, backend)),
-      output(CUDANet::Tensor({out[0]}, CUDANet::DType::FLOAT32, backend)) {
+      biases(CUDANet::Tensor(Shape{out[0]}, CUDANet::DType::FLOAT32, backend)),
+      output(CUDANet::Tensor(Shape{out[0]}, CUDANet::DType::FLOAT32, backend)) {
 
     // Allocate memory for weights and biases
     if (in.size() != 1) {
@@ -35,6 +35,8 @@ Dense::Dense(CUDANet::Backend* backend, CUDANet::Shape in, CUDANet::Shape out)
     biases.zero();
 }
 
+Dense::~Dense() {}
+
 CUDANet::Tensor& Dense::forward(const CUDANet::Tensor& input) {
     backend->dense(weights, biases, input, output, in_shape[0], out_shape[0]);
     return output;

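Review note: in the Dense initializer list, CUDANet::Tensor{{in[0] * out[0]}, ...} relies on nested brace-init, so the compiler has to deduce that the inner braces mean a one-element Shape; depending on how Shape and Tensor's constructors are declared this can pick an unintended overload or simply read ambiguously. Writing Shape{in[0] * out[0]} and calling the Tensor constructor with parentheses spells the intent out. The empty Dense::~Dense() {} presumably provides a definition for a destructor already declared in the header. A reduced illustration of the style change, with Shape assumed to be a vector-like type (not the real class):

#include <cstddef>
#include <utility>
#include <vector>

using Shape = std::vector<std::size_t>;   // assumption: Shape is a list of extents

struct Tensor {
    explicit Tensor(Shape s) : shape(std::move(s)) {}
    Shape shape;
};

int main() {
    std::size_t in0 = 3, out0 = 4;
    Tensor a{{in0 * out0}};        // nested braces: what {in0 * out0} means must be inferred
    Tensor b(Shape{in0 * out0});   // the commit's style: the Shape argument is explicit
    return (a.shape == b.shape) ? 0 : 1;
}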
View File

@@ -68,9 +68,11 @@ Tensor& Tensor::operator=(Tensor&& other) noexcept {
 }
 
 Tensor::~Tensor() {
+    if (backend && d_ptr) {
         backend->deallocate(d_ptr);
         d_ptr = nullptr;
     }
+}
 
 size_t Tensor::numel() const {
     return total_elms;
@@ -80,21 +82,6 @@ size_t Tensor::size() const {
     return total_size;
 }
 
-template <typename T>
-const T* Tensor::data() const {
-    return static_cast<T*>(d_ptr);
-}
-
-template <typename T>
-T* Tensor::data() {
-    return static_cast<T*>(d_ptr);
-}
-
 void Tensor::zero() {
     backend->zero(*this);
 }
-
-template <typename T>
-void Tensor::set_data(T *data) {
-    backend->copy_to_device(*this, data, total_size);
-}
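Review note: the destructor guard is the counterpart to the move operations just above it. After a move, the source Tensor's d_ptr is left null (and backend may be too), so unconditionally calling backend->deallocate(d_ptr) on it would either dereference a null backend or free memory that now belongs to the moved-to object. The deleted template bodies are the ones that moved into tensor.hpp. A minimal sketch of the ownership idea, with hypothetical stand-in types rather than the real classes:

#include <cstddef>
#include <cstdlib>
#include <utility>

struct Backend {                              // hypothetical stand-in for CUDANet::Backend
    void* allocate(std::size_t bytes) { return std::malloc(bytes); }
    void  deallocate(void* p)         { std::free(p); }
};

struct Buffer {                               // hypothetical stand-in for Tensor's ownership logic
    Backend* backend = nullptr;
    void*    d_ptr   = nullptr;

    Buffer(Backend* b, std::size_t bytes) : backend(b), d_ptr(b->allocate(bytes)) {}

    Buffer(Buffer&& other) noexcept : backend(other.backend), d_ptr(other.d_ptr) {
        other.d_ptr = nullptr;                // the moved-from object gives up ownership
    }

    ~Buffer() {
        if (backend && d_ptr) {               // same guard as the commit: skip moved-from or empty objects
            backend->deallocate(d_ptr);
            d_ptr = nullptr;
        }
    }
};

int main() {
    Backend be;
    Buffer a(&be, 64);
    Buffer b(std::move(a));   // a.d_ptr is now null, so destroying `a` is a no-op instead of a double free
    return 0;
}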