Refactor Tensor methods to use void* for data handling and add device_ptr method

This commit is contained in:
2025-11-27 21:18:51 +01:00
parent 9ff214d759
commit c855ae89ec
5 changed files with 24 additions and 28 deletions

View File

@@ -16,8 +16,6 @@ enum class DType
// INT32, // Not implemented yet
};
size_t dtype_size(DType dtype);
class Tensor
{
public:
@@ -38,27 +36,13 @@ public:
size_t size() const;
size_t numel() const;
template <typename T>
const T* data() const {
return static_cast<T*>(d_ptr);
}
template <typename T>
T* data() {
return static_cast<T*>(d_ptr);
}
void* device_ptr();
void zero();
template <typename T>
void fill(T value) {
backend->fill(*this, value);
}
void fill(int value);
template <typename T>
void set_data(T *data) {
backend->copy_to_device(*this, data, total_size);
}
void set_data(void *data);
private:
Shape shape;