diff --git a/include/tensor.hpp b/include/tensor.hpp index 0fb3e2a..40e42ea 100644 --- a/include/tensor.hpp +++ b/include/tensor.hpp @@ -16,8 +16,6 @@ enum class DType // INT32, // Not implemented yet }; -size_t dtype_size(DType dtype); - class Tensor { public: @@ -38,27 +36,13 @@ public: size_t size() const; size_t numel() const; - template - const T* data() const { - return static_cast(d_ptr); - } - - template - T* data() { - return static_cast(d_ptr); - } + void* device_ptr(); void zero(); - template - void fill(T value) { - backend->fill(*this, value); - } + void fill(int value); - template - void set_data(T *data) { - backend->copy_to_device(*this, data, total_size); - } + void set_data(void *data); private: Shape shape; diff --git a/src/layers/batch_norm.cpp b/src/layers/batch_norm.cpp index 6770f84..dc0a6e3 100644 --- a/src/layers/batch_norm.cpp +++ b/src/layers/batch_norm.cpp @@ -30,7 +30,7 @@ BatchNorm2d::BatchNorm2d( this->dtype = dtype; epsilon = CUDANet::Tensor({1}, dtype, backend); - epsilon.set_data(&eps); + epsilon.set_data(&eps); running_mean = CUDANet::Tensor({in_shape[2]}, dtype, backend); running_mean.zero(); @@ -81,7 +81,7 @@ size_t BatchNorm2d::output_size() { } void BatchNorm2d::set_weights(void* input) { - weights.set_data(static_cast(input)); + weights.set_data(input); } size_t BatchNorm2d::get_weights_size() { @@ -89,7 +89,7 @@ size_t BatchNorm2d::get_weights_size() { } void BatchNorm2d::set_biases(void* input) { - biases.set_data(static_cast(input)); + biases.set_data(input); } size_t BatchNorm2d::get_biases_size() { @@ -97,7 +97,7 @@ size_t BatchNorm2d::get_biases_size() { } void BatchNorm2d::set_running_mean(void* input) { - running_mean.set_data(static_cast(input)); + running_mean.set_data(input); } size_t BatchNorm2d::get_running_mean_size() { @@ -105,7 +105,7 @@ size_t BatchNorm2d::get_running_mean_size() { } void BatchNorm2d::set_running_var(void* input) { - running_var.set_data(static_cast(input)); + running_var.set_data(input); } size_t BatchNorm2d::get_running_var_size() { diff --git a/src/layers/conv2d.cpp b/src/layers/conv2d.cpp index f297ebe..111cb40 100644 --- a/src/layers/conv2d.cpp +++ b/src/layers/conv2d.cpp @@ -105,7 +105,7 @@ size_t Conv2d::output_size() { } void Conv2d::set_weights(void* input) { - weights.set_data(static_cast(input)); + weights.set_data(input); } size_t Conv2d::get_weights_size() { @@ -113,7 +113,7 @@ size_t Conv2d::get_weights_size() { } void Conv2d::set_biases(void* input) { - biases.set_data(static_cast(input)); + biases.set_data(input); } size_t Conv2d::get_biases_size() { diff --git a/src/layers/dense.cpp b/src/layers/dense.cpp index 54b02fe..db160c9 100644 --- a/src/layers/dense.cpp +++ b/src/layers/dense.cpp @@ -58,7 +58,7 @@ size_t Dense::output_size() { // TODO: Use dtype void Dense::set_weights(void* input) { - weights.set_data(static_cast(input)); + weights.set_data(input); } size_t Dense::get_weights_size() { @@ -66,7 +66,7 @@ size_t Dense::get_weights_size() { } void Dense::set_biases(void* input) { - biases.set_data(static_cast(input)); + biases.set_data(input); } size_t Dense::get_biases_size() { diff --git a/src/tensor.cpp b/src/tensor.cpp index 16ee4de..70f5d9d 100644 --- a/src/tensor.cpp +++ b/src/tensor.cpp @@ -92,6 +92,18 @@ size_t Tensor::size() const { return total_size; } +void* Tensor::device_ptr() { + return d_ptr; +} + void Tensor::zero() { backend->zero(*this); } + +void Tensor::fill(int value) { + backend->fill(*this, value); +} + +void Tensor::set_data(void *data) { + backend->copy_to_device(*this, data, total_size); +}