mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-12-22 22:34:22 +00:00
Refactor Tensor methods to use void* for data handling and add device_ptr method
This commit is contained in:
@@ -16,8 +16,6 @@ enum class DType
|
||||
// INT32, // Not implemented yet
|
||||
};
|
||||
|
||||
size_t dtype_size(DType dtype);
|
||||
|
||||
class Tensor
|
||||
{
|
||||
public:
|
||||
@@ -38,27 +36,13 @@ public:
|
||||
size_t size() const;
|
||||
size_t numel() const;
|
||||
|
||||
template <typename T>
|
||||
const T* data() const {
|
||||
return static_cast<T*>(d_ptr);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T* data() {
|
||||
return static_cast<T*>(d_ptr);
|
||||
}
|
||||
void* device_ptr();
|
||||
|
||||
void zero();
|
||||
|
||||
template <typename T>
|
||||
void fill(T value) {
|
||||
backend->fill(*this, value);
|
||||
}
|
||||
void fill(int value);
|
||||
|
||||
template <typename T>
|
||||
void set_data(T *data) {
|
||||
backend->copy_to_device(*this, data, total_size);
|
||||
}
|
||||
void set_data(void *data);
|
||||
|
||||
private:
|
||||
Shape shape;
|
||||
|
||||
Reference in New Issue
Block a user