mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-12-22 14:24:22 +00:00
Refactor Tensor methods to use void* for data handling and add device_ptr method
This commit is contained in:
@@ -16,8 +16,6 @@ enum class DType
|
||||
// INT32, // Not implemented yet
|
||||
};
|
||||
|
||||
size_t dtype_size(DType dtype);
|
||||
|
||||
class Tensor
|
||||
{
|
||||
public:
|
||||
@@ -38,27 +36,13 @@ public:
|
||||
size_t size() const;
|
||||
size_t numel() const;
|
||||
|
||||
template <typename T>
|
||||
const T* data() const {
|
||||
return static_cast<T*>(d_ptr);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T* data() {
|
||||
return static_cast<T*>(d_ptr);
|
||||
}
|
||||
void* device_ptr();
|
||||
|
||||
void zero();
|
||||
|
||||
template <typename T>
|
||||
void fill(T value) {
|
||||
backend->fill(*this, value);
|
||||
}
|
||||
void fill(int value);
|
||||
|
||||
template <typename T>
|
||||
void set_data(T *data) {
|
||||
backend->copy_to_device(*this, data, total_size);
|
||||
}
|
||||
void set_data(void *data);
|
||||
|
||||
private:
|
||||
Shape shape;
|
||||
|
||||
@@ -30,7 +30,7 @@ BatchNorm2d::BatchNorm2d(
|
||||
this->dtype = dtype;
|
||||
|
||||
epsilon = CUDANet::Tensor({1}, dtype, backend);
|
||||
epsilon.set_data<float>(&eps);
|
||||
epsilon.set_data(&eps);
|
||||
|
||||
running_mean = CUDANet::Tensor({in_shape[2]}, dtype, backend);
|
||||
running_mean.zero();
|
||||
@@ -81,7 +81,7 @@ size_t BatchNorm2d::output_size() {
|
||||
}
|
||||
|
||||
void BatchNorm2d::set_weights(void* input) {
|
||||
weights.set_data<float>(static_cast<float*>(input));
|
||||
weights.set_data(input);
|
||||
}
|
||||
|
||||
size_t BatchNorm2d::get_weights_size() {
|
||||
@@ -89,7 +89,7 @@ size_t BatchNorm2d::get_weights_size() {
|
||||
}
|
||||
|
||||
void BatchNorm2d::set_biases(void* input) {
|
||||
biases.set_data<float>(static_cast<float*>(input));
|
||||
biases.set_data(input);
|
||||
}
|
||||
|
||||
size_t BatchNorm2d::get_biases_size() {
|
||||
@@ -97,7 +97,7 @@ size_t BatchNorm2d::get_biases_size() {
|
||||
}
|
||||
|
||||
void BatchNorm2d::set_running_mean(void* input) {
|
||||
running_mean.set_data<float>(static_cast<float*>(input));
|
||||
running_mean.set_data(input);
|
||||
}
|
||||
|
||||
size_t BatchNorm2d::get_running_mean_size() {
|
||||
@@ -105,7 +105,7 @@ size_t BatchNorm2d::get_running_mean_size() {
|
||||
}
|
||||
|
||||
void BatchNorm2d::set_running_var(void* input) {
|
||||
running_var.set_data<float>(static_cast<float*>(input));
|
||||
running_var.set_data(input);
|
||||
}
|
||||
|
||||
size_t BatchNorm2d::get_running_var_size() {
|
||||
|
||||
@@ -105,7 +105,7 @@ size_t Conv2d::output_size() {
|
||||
}
|
||||
|
||||
void Conv2d::set_weights(void* input) {
|
||||
weights.set_data<float>(static_cast<float*>(input));
|
||||
weights.set_data(input);
|
||||
}
|
||||
|
||||
size_t Conv2d::get_weights_size() {
|
||||
@@ -113,7 +113,7 @@ size_t Conv2d::get_weights_size() {
|
||||
}
|
||||
|
||||
void Conv2d::set_biases(void* input) {
|
||||
biases.set_data<float>(static_cast<float*>(input));
|
||||
biases.set_data(input);
|
||||
}
|
||||
|
||||
size_t Conv2d::get_biases_size() {
|
||||
|
||||
@@ -58,7 +58,7 @@ size_t Dense::output_size() {
|
||||
|
||||
// TODO: Use dtype
|
||||
void Dense::set_weights(void* input) {
|
||||
weights.set_data<float>(static_cast<float*>(input));
|
||||
weights.set_data(input);
|
||||
}
|
||||
|
||||
size_t Dense::get_weights_size() {
|
||||
@@ -66,7 +66,7 @@ size_t Dense::get_weights_size() {
|
||||
}
|
||||
|
||||
void Dense::set_biases(void* input) {
|
||||
biases.set_data<float>(static_cast<float*>(input));
|
||||
biases.set_data(input);
|
||||
}
|
||||
|
||||
size_t Dense::get_biases_size() {
|
||||
|
||||
@@ -92,6 +92,18 @@ size_t Tensor::size() const {
|
||||
return total_size;
|
||||
}
|
||||
|
||||
void* Tensor::device_ptr() {
|
||||
return d_ptr;
|
||||
}
|
||||
|
||||
void Tensor::zero() {
|
||||
backend->zero(*this);
|
||||
}
|
||||
|
||||
void Tensor::fill(int value) {
|
||||
backend->fill(*this, value);
|
||||
}
|
||||
|
||||
void Tensor::set_data(void *data) {
|
||||
backend->copy_to_device(*this, data, total_size);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user