Refactor Tensor methods to use void* for data handling and add device_ptr method

This commit is contained in:
2025-11-27 21:18:51 +01:00
parent 9ff214d759
commit c855ae89ec
5 changed files with 24 additions and 28 deletions

View File

@@ -30,7 +30,7 @@ BatchNorm2d::BatchNorm2d(
this->dtype = dtype;
epsilon = CUDANet::Tensor({1}, dtype, backend);
epsilon.set_data<float>(&eps);
epsilon.set_data(&eps);
running_mean = CUDANet::Tensor({in_shape[2]}, dtype, backend);
running_mean.zero();
@@ -81,7 +81,7 @@ size_t BatchNorm2d::output_size() {
}
void BatchNorm2d::set_weights(void* input) {
weights.set_data<float>(static_cast<float*>(input));
weights.set_data(input);
}
size_t BatchNorm2d::get_weights_size() {
@@ -89,7 +89,7 @@ size_t BatchNorm2d::get_weights_size() {
}
void BatchNorm2d::set_biases(void* input) {
biases.set_data<float>(static_cast<float*>(input));
biases.set_data(input);
}
size_t BatchNorm2d::get_biases_size() {
@@ -97,7 +97,7 @@ size_t BatchNorm2d::get_biases_size() {
}
void BatchNorm2d::set_running_mean(void* input) {
running_mean.set_data<float>(static_cast<float*>(input));
running_mean.set_data(input);
}
size_t BatchNorm2d::get_running_mean_size() {
@@ -105,7 +105,7 @@ size_t BatchNorm2d::get_running_mean_size() {
}
void BatchNorm2d::set_running_var(void* input) {
running_var.set_data<float>(static_cast<float*>(input));
running_var.set_data(input);
}
size_t BatchNorm2d::get_running_var_size() {

View File

@@ -105,7 +105,7 @@ size_t Conv2d::output_size() {
}
void Conv2d::set_weights(void* input) {
weights.set_data<float>(static_cast<float*>(input));
weights.set_data(input);
}
size_t Conv2d::get_weights_size() {
@@ -113,7 +113,7 @@ size_t Conv2d::get_weights_size() {
}
void Conv2d::set_biases(void* input) {
biases.set_data<float>(static_cast<float*>(input));
biases.set_data(input);
}
size_t Conv2d::get_biases_size() {

View File

@@ -58,7 +58,7 @@ size_t Dense::output_size() {
// TODO: Use dtype
void Dense::set_weights(void* input) {
weights.set_data<float>(static_cast<float*>(input));
weights.set_data(input);
}
size_t Dense::get_weights_size() {
@@ -66,7 +66,7 @@ size_t Dense::get_weights_size() {
}
void Dense::set_biases(void* input) {
biases.set_data<float>(static_cast<float*>(input));
biases.set_data(input);
}
size_t Dense::get_biases_size() {

View File

@@ -92,6 +92,18 @@ size_t Tensor::size() const {
return total_size;
}
void* Tensor::device_ptr() {
return d_ptr;
}
void Tensor::zero() {
backend->zero(*this);
}
void Tensor::fill(int value) {
backend->fill(*this, value);
}
void Tensor::set_data(void *data) {
backend->copy_to_device(*this, data, total_size);
}