Refactor Backend and Layer interfaces

This commit is contained in:
2025-11-18 18:27:57 +01:00
parent 25670f90c4
commit 6340b27055
23 changed files with 154 additions and 201 deletions

47
include/tensor.hpp Normal file
View File

@@ -0,0 +1,47 @@
#pragma once
#include <cstddef>
#include <vector>
#include "backend.hpp"
#include "shape.hpp"
namespace CUDANet
{
enum class DType
{
FLOAT32,
// FLOAT16, // Not implemented yet
// INT32, // Not implemented yet
};
class Tensor
{
public:
Tensor() = default;
Tensor(Shape shape, DType dtype, CUDANet::Backend::IBackend* backend);
~Tensor();
size_t size() const;
size_t numel() const;
template <typename T>
const T* data() const;
template <typename T>
T* data();
private:
Shape shape;
DType dtype;
size_t total_elms;
size_t total_size;
CUDANet::Backend::IBackend* backend;
void* d_ptr;
};
} // namespace CUDANet