mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-12-22 14:24:22 +00:00
Fix Tensor issues
This commit is contained in:
@@ -2,10 +2,11 @@
|
|||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
|
|
||||||
#include "tensor.hpp"
|
|
||||||
|
|
||||||
namespace CUDANet {
|
namespace CUDANet {
|
||||||
|
|
||||||
|
// Forward declaration
|
||||||
|
class Tensor;
|
||||||
|
|
||||||
class Backend {
|
class Backend {
|
||||||
public:
|
public:
|
||||||
// Memory management
|
// Memory management
|
||||||
|
|||||||
@@ -34,15 +34,21 @@ public:
|
|||||||
size_t numel() const;
|
size_t numel() const;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
const T* data() const;
|
const T* data() const {
|
||||||
|
return static_cast<T*>(d_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T* data();
|
T* data() {
|
||||||
|
return static_cast<T*>(d_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
void zero();
|
void zero();
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void set_data(T *data);
|
void set_data(T *data) {
|
||||||
|
backend->copy_to_device(*this, data, total_size);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Shape shape;
|
Shape shape;
|
||||||
|
|||||||
@@ -10,10 +10,10 @@ Dense::Dense(CUDANet::Backend* backend, CUDANet::Shape in, CUDANet::Shape out)
|
|||||||
in_shape(in),
|
in_shape(in),
|
||||||
out_shape(out),
|
out_shape(out),
|
||||||
weights(
|
weights(
|
||||||
CUDANet::Tensor{{in[0] * out[0]}, CUDANet::DType::FLOAT32, backend}
|
CUDANet::Tensor(Shape{in[0] * out[0]}, CUDANet::DType::FLOAT32, backend)
|
||||||
),
|
),
|
||||||
biases(CUDANet::Tensor({out[0]}, CUDANet::DType::FLOAT32, backend)),
|
biases(CUDANet::Tensor(Shape{out[0]}, CUDANet::DType::FLOAT32, backend)),
|
||||||
output(CUDANet::Tensor({out[0]}, CUDANet::DType::FLOAT32, backend)) {
|
output(CUDANet::Tensor(Shape{out[0]}, CUDANet::DType::FLOAT32, backend)) {
|
||||||
// Allocate memory for weights and biases
|
// Allocate memory for weights and biases
|
||||||
|
|
||||||
if (in.size() != 1) {
|
if (in.size() != 1) {
|
||||||
@@ -35,6 +35,8 @@ Dense::Dense(CUDANet::Backend* backend, CUDANet::Shape in, CUDANet::Shape out)
|
|||||||
biases.zero();
|
biases.zero();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Dense::~Dense() {}
|
||||||
|
|
||||||
CUDANet::Tensor& Dense::forward(const CUDANet::Tensor& input) {
|
CUDANet::Tensor& Dense::forward(const CUDANet::Tensor& input) {
|
||||||
backend->dense(weights, biases, input, output, in_shape[0], out_shape[0]);
|
backend->dense(weights, biases, input, output, in_shape[0], out_shape[0]);
|
||||||
return output;
|
return output;
|
||||||
|
|||||||
@@ -68,9 +68,11 @@ Tensor& Tensor::operator=(Tensor&& other) noexcept {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Tensor::~Tensor() {
|
Tensor::~Tensor() {
|
||||||
|
if (backend && d_ptr) {
|
||||||
backend->deallocate(d_ptr);
|
backend->deallocate(d_ptr);
|
||||||
d_ptr = nullptr;
|
d_ptr = nullptr;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
size_t Tensor::numel() const {
|
size_t Tensor::numel() const {
|
||||||
return total_elms;
|
return total_elms;
|
||||||
@@ -80,21 +82,6 @@ size_t Tensor::size() const {
|
|||||||
return total_size;
|
return total_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
const T* Tensor::data() const {
|
|
||||||
return static_cast<T*>(d_ptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
T* Tensor::data() {
|
|
||||||
return static_cast<T*>(d_ptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
void Tensor::zero() {
|
void Tensor::zero() {
|
||||||
backend->zero(*this);
|
backend->zero(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void Tensor::set_data(T *data) {
|
|
||||||
backend->copy_to_device(*this, data, total_size);
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user