Refactor Tensor methods to use void* for data handling and add device_ptr method

2025-11-27 21:18:51 +01:00
parent 9ff214d759
commit c855ae89ec
5 changed files with 24 additions and 28 deletions

View File

@@ -16,8 +16,6 @@ enum class DType
     // INT32, // Not implemented yet
 };
 
-size_t dtype_size(DType dtype);
-
 class Tensor
 {
 public:
@@ -38,27 +36,13 @@ public:
     size_t size() const;
     size_t numel() const;
 
-    template <typename T>
-    const T* data() const {
-        return static_cast<T*>(d_ptr);
-    }
-
-    template <typename T>
-    T* data() {
-        return static_cast<T*>(d_ptr);
-    }
+    void* device_ptr();
 
     void zero();
 
-    template <typename T>
-    void fill(T value) {
-        backend->fill(*this, value);
-    }
+    void fill(int value);
 
-    template <typename T>
-    void set_data(T *data) {
-        backend->copy_to_device(*this, data, total_size);
-    }
+    void set_data(void *data);
 
 private:
     Shape shape;
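
For reference, a hedged caller-side sketch of what the header change above means in practice. The helper function, the include path, and the float element type are assumptions for illustration, not part of this commit:

    #include <vector>
    #include "tensor.hpp"  // include path is an assumption

    // Hypothetical helper contrasting the removed templated API with the new
    // untyped one; assumes a FLOAT32 tensor.
    void upload_ones(CUDANet::Tensor& t) {
        std::vector<float> host(t.numel(), 1.0f);

        // Old interface (removed above):
        //   t.set_data<float>(host.data());
        //   const float* dev = t.data<float>();

        // New interface: the element type is carried by the tensor's DType.
        t.set_data(host.data());     // copy host bytes into the device buffer
        void* dev = t.device_ptr();  // raw device pointer, cast at the call site
        (void)dev;
    }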

View File

@@ -30,7 +30,7 @@ BatchNorm2d::BatchNorm2d(
     this->dtype = dtype;
 
     epsilon = CUDANet::Tensor({1}, dtype, backend);
-    epsilon.set_data<float>(&eps);
+    epsilon.set_data(&eps);
 
     running_mean = CUDANet::Tensor({in_shape[2]}, dtype, backend);
     running_mean.zero();
@@ -81,7 +81,7 @@ size_t BatchNorm2d::output_size() {
 }
 
 void BatchNorm2d::set_weights(void* input) {
-    weights.set_data<float>(static_cast<float*>(input));
+    weights.set_data(input);
 }
 
 size_t BatchNorm2d::get_weights_size() {
@@ -89,7 +89,7 @@ size_t BatchNorm2d::get_weights_size() {
 }
 
 void BatchNorm2d::set_biases(void* input) {
-    biases.set_data<float>(static_cast<float*>(input));
+    biases.set_data(input);
 }
 
 size_t BatchNorm2d::get_biases_size() {
@@ -97,7 +97,7 @@ size_t BatchNorm2d::get_biases_size() {
 }
 
 void BatchNorm2d::set_running_mean(void* input) {
-    running_mean.set_data<float>(static_cast<float*>(input));
+    running_mean.set_data(input);
 }
 
 size_t BatchNorm2d::get_running_mean_size() {
@@ -105,7 +105,7 @@ size_t BatchNorm2d::get_running_mean_size() {
 }
 
 void BatchNorm2d::set_running_var(void* input) {
-    running_var.set_data<float>(static_cast<float*>(input));
+    running_var.set_data(input);
 }
 
 size_t BatchNorm2d::get_running_var_size() {

View File

@@ -105,7 +105,7 @@ size_t Conv2d::output_size() {
 }
 
 void Conv2d::set_weights(void* input) {
-    weights.set_data<float>(static_cast<float*>(input));
+    weights.set_data(input);
 }
 
 size_t Conv2d::get_weights_size() {
@@ -113,7 +113,7 @@ size_t Conv2d::get_weights_size() {
 }
 
 void Conv2d::set_biases(void* input) {
-    biases.set_data<float>(static_cast<float*>(input));
+    biases.set_data(input);
 }
 
 size_t Conv2d::get_biases_size() {

View File

@@ -58,7 +58,7 @@ size_t Dense::output_size() {
 
 // TODO: Use dtype
 void Dense::set_weights(void* input) {
-    weights.set_data<float>(static_cast<float*>(input));
+    weights.set_data(input);
 }
 
 size_t Dense::get_weights_size() {
@@ -66,7 +66,7 @@ size_t Dense::get_weights_size() {
 }
 
 void Dense::set_biases(void* input) {
-    biases.set_data<float>(static_cast<float*>(input));
+    biases.set_data(input);
 }
 
 size_t Dense::get_biases_size() {

View File

@@ -92,6 +92,18 @@ size_t Tensor::size() const {
     return total_size;
 }
 
+void* Tensor::device_ptr() {
+    return d_ptr;
+}
+
 void Tensor::zero() {
     backend->zero(*this);
 }
+
+void Tensor::fill(int value) {
+    backend->fill(*this, value);
+}
+
+void Tensor::set_data(void *data) {
+    backend->copy_to_device(*this, data, total_size);
+}