mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-12-22 14:24:22 +00:00
Refactor Tensor methods to use void* for data handling and add device_ptr method
This commit is contained in:
@@ -16,8 +16,6 @@ enum class DType
|
|||||||
// INT32, // Not implemented yet
|
// INT32, // Not implemented yet
|
||||||
};
|
};
|
||||||
|
|
||||||
size_t dtype_size(DType dtype);
|
|
||||||
|
|
||||||
class Tensor
|
class Tensor
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
@@ -38,27 +36,13 @@ public:
|
|||||||
size_t size() const;
|
size_t size() const;
|
||||||
size_t numel() const;
|
size_t numel() const;
|
||||||
|
|
||||||
template <typename T>
|
void* device_ptr();
|
||||||
const T* data() const {
|
|
||||||
return static_cast<T*>(d_ptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
T* data() {
|
|
||||||
return static_cast<T*>(d_ptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
void zero();
|
void zero();
|
||||||
|
|
||||||
template <typename T>
|
void fill(int value);
|
||||||
void fill(T value) {
|
|
||||||
backend->fill(*this, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
void set_data(void *data);
|
||||||
void set_data(T *data) {
|
|
||||||
backend->copy_to_device(*this, data, total_size);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Shape shape;
|
Shape shape;
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ BatchNorm2d::BatchNorm2d(
|
|||||||
this->dtype = dtype;
|
this->dtype = dtype;
|
||||||
|
|
||||||
epsilon = CUDANet::Tensor({1}, dtype, backend);
|
epsilon = CUDANet::Tensor({1}, dtype, backend);
|
||||||
epsilon.set_data<float>(&eps);
|
epsilon.set_data(&eps);
|
||||||
|
|
||||||
running_mean = CUDANet::Tensor({in_shape[2]}, dtype, backend);
|
running_mean = CUDANet::Tensor({in_shape[2]}, dtype, backend);
|
||||||
running_mean.zero();
|
running_mean.zero();
|
||||||
@@ -81,7 +81,7 @@ size_t BatchNorm2d::output_size() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void BatchNorm2d::set_weights(void* input) {
|
void BatchNorm2d::set_weights(void* input) {
|
||||||
weights.set_data<float>(static_cast<float*>(input));
|
weights.set_data(input);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t BatchNorm2d::get_weights_size() {
|
size_t BatchNorm2d::get_weights_size() {
|
||||||
@@ -89,7 +89,7 @@ size_t BatchNorm2d::get_weights_size() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void BatchNorm2d::set_biases(void* input) {
|
void BatchNorm2d::set_biases(void* input) {
|
||||||
biases.set_data<float>(static_cast<float*>(input));
|
biases.set_data(input);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t BatchNorm2d::get_biases_size() {
|
size_t BatchNorm2d::get_biases_size() {
|
||||||
@@ -97,7 +97,7 @@ size_t BatchNorm2d::get_biases_size() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void BatchNorm2d::set_running_mean(void* input) {
|
void BatchNorm2d::set_running_mean(void* input) {
|
||||||
running_mean.set_data<float>(static_cast<float*>(input));
|
running_mean.set_data(input);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t BatchNorm2d::get_running_mean_size() {
|
size_t BatchNorm2d::get_running_mean_size() {
|
||||||
@@ -105,7 +105,7 @@ size_t BatchNorm2d::get_running_mean_size() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void BatchNorm2d::set_running_var(void* input) {
|
void BatchNorm2d::set_running_var(void* input) {
|
||||||
running_var.set_data<float>(static_cast<float*>(input));
|
running_var.set_data(input);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t BatchNorm2d::get_running_var_size() {
|
size_t BatchNorm2d::get_running_var_size() {
|
||||||
|
|||||||
@@ -105,7 +105,7 @@ size_t Conv2d::output_size() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Conv2d::set_weights(void* input) {
|
void Conv2d::set_weights(void* input) {
|
||||||
weights.set_data<float>(static_cast<float*>(input));
|
weights.set_data(input);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t Conv2d::get_weights_size() {
|
size_t Conv2d::get_weights_size() {
|
||||||
@@ -113,7 +113,7 @@ size_t Conv2d::get_weights_size() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Conv2d::set_biases(void* input) {
|
void Conv2d::set_biases(void* input) {
|
||||||
biases.set_data<float>(static_cast<float*>(input));
|
biases.set_data(input);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t Conv2d::get_biases_size() {
|
size_t Conv2d::get_biases_size() {
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ size_t Dense::output_size() {
|
|||||||
|
|
||||||
// TODO: Use dtype
|
// TODO: Use dtype
|
||||||
void Dense::set_weights(void* input) {
|
void Dense::set_weights(void* input) {
|
||||||
weights.set_data<float>(static_cast<float*>(input));
|
weights.set_data(input);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t Dense::get_weights_size() {
|
size_t Dense::get_weights_size() {
|
||||||
@@ -66,7 +66,7 @@ size_t Dense::get_weights_size() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Dense::set_biases(void* input) {
|
void Dense::set_biases(void* input) {
|
||||||
biases.set_data<float>(static_cast<float*>(input));
|
biases.set_data(input);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t Dense::get_biases_size() {
|
size_t Dense::get_biases_size() {
|
||||||
|
|||||||
@@ -92,6 +92,18 @@ size_t Tensor::size() const {
|
|||||||
return total_size;
|
return total_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void* Tensor::device_ptr() {
|
||||||
|
return d_ptr;
|
||||||
|
}
|
||||||
|
|
||||||
void Tensor::zero() {
|
void Tensor::zero() {
|
||||||
backend->zero(*this);
|
backend->zero(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Tensor::fill(int value) {
|
||||||
|
backend->fill(*this, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Tensor::set_data(void *data) {
|
||||||
|
backend->copy_to_device(*this, data, total_size);
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user