Compare commits

...

3 Commits

SHA1 Message Date
13d3d38b68 Add dtype parameter to layer constructors 2025-11-26 00:19:33 +01:00
84153ac49c Add default dtype to backend 2025-11-25 23:42:19 +01:00
ad079560ff Update CMakeLists.txt 2025-11-25 19:08:55 +01:00
24 changed files with 243 additions and 74 deletions

View File

@@ -23,8 +23,8 @@ endif()
file(GLOB_RECURSE CPU_SOURCES
src/*.cpp
src/layers/*.cpp
src/model/*.cpp
)
set(LIBRARY_SOURCES ${CPU_SOURCES})
@@ -32,10 +32,7 @@ set(LIBRARY_SOURCES ${CPU_SOURCES})
if(USE_CUDA)
file(GLOB_RECURSE CUDA_SOURCES
src/backends/cuda/*.cu
src/backends/cuda/utils/*.cu
src/backends/cuda/kernels/*.cu
src/backends/cuda/layers/*.cu
src/layers/*.cu # To be removed
)
set(LIBRARY_SOURCES ${LIBRARY_SOURCES} ${CUDA_SOURCES})
endif()
@@ -52,11 +49,6 @@ endif()
# Set include directories for the library
target_include_directories(${PROJECT_NAME} PUBLIC
${CMAKE_CURRENT_SOURCE_DIR}/include
${CMAKE_CURRENT_SOURCE_DIR}/include/utils
${CMAKE_CURRENT_SOURCE_DIR}/include/kernels
${CMAKE_CURRENT_SOURCE_DIR}/include/layers
${CMAKE_CURRENT_SOURCE_DIR}/include/model
${CMAKE_CURRENT_SOURCE_DIR}/src
)
set_property(TARGET ${PROJECT_NAME} PROPERTY CXX_STANDARD 20)

View File

@@ -1,8 +1,10 @@
#pragma once
#include <memory>
#include <optional>
#include "shape.hpp"
#include "tensor.hpp"
namespace CUDANet {
@@ -22,7 +24,14 @@ class BackendFactory {
};
class Backend {
protected:
std::optional<DType> default_dtype;
public:
virtual bool supports_dtype(DType dtype) const = 0;
virtual void set_default_dtype(DType dtype) = 0;
virtual DType get_default_dtype() const = 0;
// Memory management
virtual void* allocate(size_t bytes) = 0;
virtual void deallocate(void* ptr) = 0;

View File

@@ -1,6 +1,7 @@
#pragma once
#include <cstdio>
#include <set>
#include "backend.hpp"
#include "tensor.hpp"
@@ -29,9 +30,14 @@ namespace CUDANet::Backends {
class CUDA : public Backend {
private:
int device_id;
std::set<DType> supported_dtypes;
public:
CUDA(const BackendConfig& config);
bool supports_dtype(DType dtype) const override;
void set_default_dtype(DType dtype) override;
DType get_default_dtype() const override;
static bool is_cuda_available();
void initialize();
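
Taken together, the new backend interface gives callers a dtype negotiation step before any tensors are created. A minimal usage sketch, assuming BackendConfig exposes the device_id field the implementation further down reads (hypothetical otherwise):

#include "backend/cuda/cuda.cuh"

int main() {
    if (!CUDANet::Backends::CUDA::is_cuda_available()) {
        return 1;
    }

    CUDANet::BackendConfig config;
    config.device_id = 0;  // field name taken from the CUDA constructor below

    CUDANet::Backends::CUDA backend(config);

    // FLOAT32 is the only dtype the backend registers so far.
    if (backend.supports_dtype(CUDANet::DType::FLOAT32)) {
        backend.set_default_dtype(CUDANet::DType::FLOAT32);
    }
    return 0;
}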

View File

@@ -16,6 +16,8 @@ namespace CUDANet {
*
*/
class Layer {
protected:
CUDANet::DType dtype;
public:
virtual ~Layer(){};
@@ -39,4 +41,4 @@ class Layer {
virtual size_t get_biases_size() = 0;
};
} // namespace CUDANet::Layers
} // namespace CUDANet
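
The protected dtype member sets up the pattern every layer in this changeset follows: the old constructor survives as a thin wrapper that forwards backend->get_default_dtype() to a new overload taking an explicit dtype. A condensed sketch of that pattern with a hypothetical layer (pure-virtual overrides omitted):

class MyLayer : public CUDANet::Layer {
  public:
    // Back-compatible constructor: resolve dtype from the backend default.
    MyLayer(CUDANet::Shape shape, CUDANet::Backend* backend)
        : MyLayer(shape, backend->get_default_dtype(), backend) {}

    // New overload: the caller pins the dtype explicitly.
    MyLayer(CUDANet::Shape shape, CUDANet::DType dtype, CUDANet::Backend* backend)
        : backend(backend) {
        this->dtype = dtype;  // protected member inherited from CUDANet::Layer
        output = CUDANet::Tensor(shape, dtype, backend);
    }

  private:
    CUDANet::Backend* backend;
    CUDANet::Tensor output;
};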

View File

@@ -20,12 +20,13 @@ enum ActivationType { SIGMOID, RELU, SOFTMAX, NONE };
* @brief Utility class that performs activation
*
*/
class Activation : public Layer {
class Activation : public CUDANet::Layer {
public:
Activation() = default;
Activation(ActivationType activation, const CUDANet::Shape &shape, CUDANet::Backend* backend);
Activation(ActivationType activation, const CUDANet::Shape &shape, CUDANet::DType dtype, CUDANet::Backend* backend);
~Activation() = default;
@@ -50,7 +51,7 @@ class Activation : public Layer {
private:
CUDANet::Backend* backend;
ActivationType activationType;
ActivationType activation_type;
CUDANet::Shape shape;
CUDANet::Tensor softmax_sum;

View File

@@ -8,6 +8,7 @@ namespace CUDANet::Layers {
class Add {
public:
Add(CUDANet::Shape a_shape, CUDANet::Shape b_shape, CUDANet::Backend* backend);
Add(CUDANet::Shape a_shape, CUDANet::Shape b_shape, CUDANet::DType dtype, CUDANet::Backend* backend);
~Add();
@@ -19,6 +20,8 @@ class Add {
CUDANet::Tensor output;
CUDANet::Backend *backend;
CUDANet::DType dtype;
};
} // namespace CUDANet::Layers

View File

@@ -4,7 +4,7 @@
namespace CUDANet::Layers {
class AvgPool2d : public Layer {
class AvgPool2d : public CUDANet::Layer {
public:
AvgPool2d(
CUDANet::Shape input_shape,
@@ -13,6 +13,14 @@ class AvgPool2d : public Layer {
CUDANet::Shape padding_shape,
CUDANet::Backend *backend
);
AvgPool2d(
CUDANet::Shape input_shape,
CUDANet::Shape pool_shape,
CUDANet::Shape stride_shape,
CUDANet::Shape padding_shape,
CUDANet::DType dtype,
CUDANet::Backend *backend
);
~AvgPool2d();
@@ -50,6 +58,7 @@ class AvgPool2d : public Layer {
class AdaptiveAvgPool2d : public AvgPool2d {
public:
AdaptiveAvgPool2d(CUDANet::Shape input_shape, CUDANet::Shape output_shape, CUDANet::Backend *backend);
AdaptiveAvgPool2d(CUDANet::Shape input_shape, CUDANet::Shape output_shape, CUDANet::DType dtype, CUDANet::Backend *backend);
};
} // namespace CUDANet::Layers

View File

@@ -4,9 +4,10 @@
namespace CUDANet::Layers {
class BatchNorm2d : public Layer {
class BatchNorm2d : public CUDANet::Layer {
public:
BatchNorm2d(CUDANet::Shape input_shape, float epsilon, CUDANet::Backend *backend);
BatchNorm2d(CUDANet::Shape input_shape, float epsilon, CUDANet::DType dtype, CUDANet::Backend *backend);
~BatchNorm2d();

View File

@@ -12,6 +12,7 @@ class Concat {
public:
Concat(const CUDANet::Shape a_shape, const CUDANet::Shape b_shape, CUDANet::Backend *backend);
Concat(const CUDANet::Shape a_shape, const CUDANet::Shape b_shape, CUDANet::DType dtype, CUDANet::Backend *backend);
~Concat();
@@ -27,6 +28,8 @@ class Concat {
CUDANet::Tensor output;
CUDANet::Backend *backend;
CUDANet::DType dtype;
};
} // namespace CUDANet::Layers

View File

@@ -8,7 +8,7 @@ namespace CUDANet::Layers {
* @brief 2D convolutional layer
*
*/
class Conv2d : public Layer {
class Conv2d : public CUDANet::Layer {
public:
Conv2d(
CUDANet::Shape input_shape,
@@ -17,6 +17,14 @@ class Conv2d : public Layer {
CUDANet::Shape padding_shape,
CUDANet::Backend* backend
);
Conv2d(
CUDANet::Shape input_shape,
CUDANet::Shape kernel_shape,
CUDANet::Shape stride_shape,
CUDANet::Shape padding_shape,
CUDANet::DType dtype,
CUDANet::Backend* backend
);
~Conv2d();

View File

@@ -9,10 +9,11 @@ namespace CUDANet::Layers {
* @brief Dense (fully connected) layer
*
*/
class Dense : public Layer {
class Dense : public CUDANet::Layer {
public:
Dense(CUDANet::Shape input_shape, CUDANet::Shape output_shape, CUDANet::Backend *backend);
Dense(CUDANet::Shape input_shape, CUDANet::Shape output_shape, CUDANet::DType dtype, CUDANet::Backend *backend);
~Dense();

View File

@@ -4,7 +4,7 @@
namespace CUDANet::Layers {
class MaxPool2d : public Layer {
class MaxPool2d : public CUDANet::Layer {
public:
MaxPool2d(
CUDANet::Shape input_shape,
@@ -13,6 +13,14 @@ class MaxPool2d : public Layer {
CUDANet::Shape padding_shape,
CUDANet::Backend* backend
);
MaxPool2d(
CUDANet::Shape input_shape,
CUDANet::Shape pool_shape,
CUDANet::Shape stride_shape,
CUDANet::Shape padding_shape,
CUDANet::DType dtype,
CUDANet::Backend* backend
);
~MaxPool2d();
CUDANet::Tensor& forward(CUDANet::Tensor &input) override;

View File

@@ -66,6 +66,12 @@ struct Shape {
__host__ bool operator!=(const Shape& other) const {
return !(*this == other);
}
__host__ __device__ bool empty() const {
return ndim == 0;
}
};
std::string format_shape(const Shape& shape) {

View File

@@ -16,11 +16,14 @@ enum class DType
// INT32, // Not implemented yet
};
size_t dtype_size(DType dtype);
class Tensor
{
public:
Tensor() = default;
Tensor(Shape shape, CUDANet::Backend* backend);
Tensor(Shape shape, DType dtype, CUDANet::Backend* backend);
Tensor(Tensor&& other) noexcept;
@@ -30,6 +33,8 @@ public:
~Tensor();
DType get_dtype();
size_t size() const;
size_t numel() const;
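
With the constructors split this way, a tensor can either inherit the backend default or pin a dtype explicitly; a short sketch, assuming the backend object from earlier:

// Delegating constructor: dtype comes from backend.get_default_dtype().
CUDANet::Tensor a(CUDANet::Shape{3, 224, 224}, &backend);

// Explicit dtype; FLOAT32 is the only value implemented so far.
CUDANet::Tensor b(CUDANet::Shape{3, 224, 224}, CUDANet::DType::FLOAT32, &backend);

Presumably numel() reports the element count and size() the byte count (numel() times dtype_size(dtype)), though the header alone does not pin that down.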

View File

@@ -5,12 +5,15 @@
#include <format>
#include <stdexcept>
#include "backend/cuda/cuda.cuh"
#include "tensor.hpp"
using namespace CUDANet::Backends;
CUDA::CUDA(const BackendConfig& config) {
device_id = config.device_id < 0 ? 0 : config.device_id;
supported_dtypes = {DType::FLOAT32};
default_dtype = DType::FLOAT32;
initialize();
}
@@ -41,6 +44,28 @@ void CUDA::initialize() {
std::printf("Using CUDA device %d: %s\n", device_id, deviceProp.name);
}
bool CUDA::supports_dtype(DType dtype) const {
return supported_dtypes.contains(dtype);
}
void CUDA::set_default_dtype(DType dtype) {
if (!supported_dtypes.contains(dtype)) {
throw std::runtime_error("Unsupported dtype");
}
default_dtype = dtype;
}
CUDANet::DType CUDA::get_default_dtype() const {
    // Fall back to FLOAT32 without mutating state from a const method.
    return default_dtype.value_or(DType::FLOAT32);
}
void* CUDA::allocate(size_t bytes) {
void* d_ptr = nullptr;
CUDA_CHECK(cudaMalloc(&d_ptr, bytes));

View File

@@ -1,41 +1,58 @@
#include "activation.hpp"
#include <format>
#include <stdexcept>
#include <vector>
#include "activation.hpp"
#include "tensor.hpp"
using namespace CUDANet::Layers;
Activation::Activation(ActivationType activation, const CUDANet::Shape &shape, CUDANet::Backend* backend)
: backend(backend), activationType(activation), shape(shape) {
Activation::Activation(
ActivationType activation,
const CUDANet::Shape& shape,
CUDANet::Backend* backend
)
: Activation(activation, shape, backend->get_default_dtype(), backend) {}
Activation::Activation(
ActivationType activation,
const CUDANet::Shape& shape,
CUDANet::DType dtype,
CUDANet::Backend* backend
)
: activation_type(activation),
shape(shape),
backend(backend) {
this->dtype = dtype;
if (shape.size() != 1) {
throw InvalidShapeException("input", 1, shape.size());
}
auto length = shape[0];
if (activationType == SOFTMAX) {
softmax_sum = CUDANet::Tensor({static_cast<size_t>(length)}, CUDANet::DType::FLOAT32, backend);
tensor_max = CUDANet::Tensor({static_cast<size_t>(length)}, CUDANet::DType::FLOAT32, backend);
if (activation_type == SOFTMAX) {
softmax_sum =
CUDANet::Tensor({static_cast<size_t>(length)}, dtype, backend);
tensor_max =
CUDANet::Tensor({static_cast<size_t>(length)}, dtype, backend);
}
}
CUDANet::Tensor& Activation::forward(CUDANet::Tensor &input) {
switch (activationType)
{
case ActivationType::SIGMOID:
backend->sigmoid(input);
break;
case ActivationType::RELU:
backend->relu(input);
break;
case ActivationType::SOFTMAX:
backend->softmax(input, tensor_max, softmax_sum);
break;
default:
break;
CUDANet::Tensor& Activation::forward(CUDANet::Tensor& input) {
switch (activation_type) {
case ActivationType::SIGMOID:
backend->sigmoid(input);
break;
case ActivationType::RELU:
backend->relu(input);
break;
case ActivationType::SOFTMAX:
backend->softmax(input, tensor_max, softmax_sum);
break;
default:
break;
}
return input;
@@ -57,13 +74,13 @@ size_t Activation::output_size() {
return shape[0];
}
void Activation::set_weights(void *input) {}
void Activation::set_weights(void* input) {}
size_t Activation::get_weights_size() {
return 0;
}
void Activation::set_biases(void *input) {}
void Activation::set_biases(void* input) {}
size_t Activation::get_biases_size() {
return 0;
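
A short usage sketch of the rewritten constructor chain, assuming a 1-D shape as the validation above requires:

// SOFTMAX allocates its scratch tensors (tensor_max, softmax_sum) in the
// layer's dtype, which here resolves to the backend default.
CUDANet::Layers::Activation softmax(
    CUDANet::Layers::SOFTMAX, CUDANet::Shape{1000}, &backend
);
CUDANet::Tensor logits(CUDANet::Shape{1000}, &backend);
CUDANet::Tensor& probs = softmax.forward(logits);  // activation applied in place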

View File

@@ -3,7 +3,11 @@
using namespace CUDANet::Layers;
Add::Add(CUDANet::Shape a_shape, CUDANet::Shape b_shape, CUDANet::Backend* backend) : backend(backend) {
Add::Add(CUDANet::Shape a_shape, CUDANet::Shape b_shape, CUDANet::Backend* backend)
: Add(a_shape, b_shape, backend->get_default_dtype(), backend) {}
Add::Add(CUDANet::Shape a_shape, CUDANet::Shape b_shape, CUDANet::DType dtype, CUDANet::Backend* backend)
: backend(backend), dtype(dtype) {
if (a_shape != b_shape) {
throw InvalidShapeException(
"Add requires matching dimensions", a_shape, b_shape
@@ -11,7 +15,7 @@ Add::Add(CUDANet::Shape a_shape, CUDANet::Shape b_shape, CUDANet::Backend* backe
}
out_shape = a_shape;
output = CUDANet::Tensor(out_shape, CUDANet::DType::FLOAT32, backend);
output = CUDANet::Tensor(out_shape, dtype, backend);
}
Add::~Add() {}

View File

@@ -11,6 +11,16 @@ AvgPool2d::AvgPool2d(
CUDANet::Shape stride_shape,
CUDANet::Shape padding_shape,
CUDANet::Backend* backend
)
: AvgPool2d(input_shape, pool_shape, stride_shape, padding_shape, backend->get_default_dtype(), backend) {}
AvgPool2d::AvgPool2d(
CUDANet::Shape input_shape,
CUDANet::Shape pool_shape,
CUDANet::Shape stride_shape,
CUDANet::Shape padding_shape,
CUDANet::DType dtype,
CUDANet::Backend* backend
)
: in_shape(input_shape),
pool_shape(pool_shape),
@@ -33,6 +43,8 @@ AvgPool2d::AvgPool2d(
throw InvalidShapeException("padding", 2, padding_shape.size());
}
this->dtype = dtype;
out_shape = {
(in_shape[0] + 2 * padding_shape[0] - pool_shape[0]) / stride_shape[0] +
1,
@@ -43,7 +55,7 @@ AvgPool2d::AvgPool2d(
output = CUDANet::Tensor(
Shape{out_shape[0] * out_shape[1] * out_shape[2]},
CUDANet::DType::FLOAT32, backend
dtype, backend
);
}
@@ -96,6 +108,14 @@ AdaptiveAvgPool2d::AdaptiveAvgPool2d(
CUDANet::Shape input_shape,
CUDANet::Shape output_shape,
CUDANet::Backend *backend
)
: AdaptiveAvgPool2d(input_shape, output_shape, backend->get_default_dtype(), backend) {}
AdaptiveAvgPool2d::AdaptiveAvgPool2d(
CUDANet::Shape input_shape,
CUDANet::Shape output_shape,
CUDANet::DType dtype,
CUDANet::Backend *backend
)
: AvgPool2d(
input_shape,
@@ -114,12 +134,13 @@ AdaptiveAvgPool2d::AdaptiveAvgPool2d(
(input_shape[0] - (output_shape[0] - 1) * (input_shape[0] / output_shape[0]) - 1) / 2,
(input_shape[1] - (output_shape[1] - 1) * (input_shape[1] / output_shape[1]) - 1) / 2
},
dtype,
backend
) {
out_shape = output_shape;
output = CUDANet::Tensor(
Shape{out_shape[0] * out_shape[1] * out_shape[2]},
CUDANet::DType::FLOAT32, backend
dtype, backend
);
}

View File

@@ -12,6 +12,14 @@ BatchNorm2d::BatchNorm2d(
CUDANet::Shape input_shape,
float eps,
CUDANet::Backend *backend
)
: BatchNorm2d(input_shape, eps, backend->get_default_dtype(), backend) {}
BatchNorm2d::BatchNorm2d(
CUDANet::Shape input_shape,
float eps,
CUDANet::DType dtype,
CUDANet::Backend *backend
)
: in_shape(input_shape), backend(backend) {
@@ -19,22 +27,24 @@ BatchNorm2d::BatchNorm2d(
throw InvalidShapeException("input", 3, in_shape.size());
}
epsilon = CUDANet::Tensor({1}, CUDANet::DType::FLOAT32, backend);
this->dtype = dtype;
epsilon = CUDANet::Tensor({1}, dtype, backend);
epsilon.set_data<float>(&eps);
running_mean = CUDANet::Tensor({in_shape[2]}, CUDANet::DType::FLOAT32, backend);
running_mean = CUDANet::Tensor({in_shape[2]}, dtype, backend);
running_mean.zero();
running_var = CUDANet::Tensor({in_shape[2]}, CUDANet::DType::FLOAT32, backend);
running_var = CUDANet::Tensor({in_shape[2]}, dtype, backend);
running_var.fill(1);
weights = CUDANet::Tensor({in_shape[2]}, CUDANet::DType::FLOAT32, backend);
weights = CUDANet::Tensor({in_shape[2]}, dtype, backend);
weights.fill(1);
biases = CUDANet::Tensor({in_shape[2]}, CUDANet::DType::FLOAT32, backend);
biases = CUDANet::Tensor({in_shape[2]}, dtype, backend);
biases.zero();
output = CUDANet::Tensor(in_shape, CUDANet::DType::FLOAT32, backend);
output = CUDANet::Tensor(in_shape, dtype, backend);
}
BatchNorm2d::~BatchNorm2d() {}

View File

@@ -3,7 +3,10 @@
using namespace CUDANet::Layers;
Concat::Concat(const CUDANet::Shape a_shape, const CUDANet::Shape b_shape, CUDANet::Backend *backend)
: a_shape(a_shape), b_shape(b_shape), backend(backend) {
: Concat(a_shape, b_shape, backend->get_default_dtype(), backend) {}
Concat::Concat(const CUDANet::Shape a_shape, const CUDANet::Shape b_shape, CUDANet::DType dtype, CUDANet::Backend *backend)
: a_shape(a_shape), b_shape(b_shape), backend(backend), dtype(dtype) {
if (a_shape[0] != b_shape[0] || a_shape[1] != b_shape[1]) {
throw InvalidShapeException(
"Concat requires matching height and width dimensions", a_shape,
@@ -12,7 +15,7 @@ Concat::Concat(const CUDANet::Shape a_shape, const CUDANet::Shape b_shape, CUDAN
}
out_shape = {a_shape[0], a_shape[1], a_shape[2] + b_shape[2]};
output = CUDANet::Tensor(out_shape, CUDANet::DType::FLOAT32, backend);
output = CUDANet::Tensor(out_shape, dtype, backend);
}
Concat::~Concat() {}

View File

@@ -14,6 +14,16 @@ Conv2d::Conv2d(
CUDANet::Shape stride_shape,
CUDANet::Shape padding_shape,
CUDANet::Backend* backend
)
: Conv2d(input_shape, kernel_shape, stride_shape, padding_shape, backend->get_default_dtype(), backend) {}
Conv2d::Conv2d(
CUDANet::Shape input_shape,
CUDANet::Shape kernel_shape,
CUDANet::Shape stride_shape,
CUDANet::Shape padding_shape,
CUDANet::DType dtype,
CUDANet::Backend* backend
)
: in_shape(input_shape),
kernel_shape(kernel_shape),
@@ -36,6 +46,8 @@ Conv2d::Conv2d(
throw InvalidShapeException("padding", 3, padding_shape.size());
}
this->dtype = dtype;
out_shape = {
(in_shape[0] - kernel_shape[0] + 2 * padding_shape[0]) /
stride_shape[0] +
@@ -48,17 +60,17 @@ Conv2d::Conv2d(
output = CUDANet::Tensor(
Shape{out_shape[0], out_shape[1], out_shape[2]},
CUDANet::DType::FLOAT32, backend
dtype, backend
);
weights = CUDANet::Tensor(
Shape{
kernel_shape[0], kernel_shape[1], kernel_shape[2], in_shape[2]
},
CUDANet::DType::FLOAT32, backend
dtype, backend
);
biases = CUDANet::Tensor(
Shape{kernel_shape[2]}, CUDANet::DType::FLOAT32, backend
Shape{kernel_shape[2]}, dtype, backend
);
weights.zero();

View File

@@ -6,6 +6,9 @@
using namespace CUDANet::Layers;
Dense::Dense(CUDANet::Shape in_shape, CUDANet::Shape out_shape, CUDANet::Backend* backend)
: Dense(in_shape, out_shape, backend->get_default_dtype(), backend) {}
Dense::Dense(CUDANet::Shape in_shape, CUDANet::Shape out_shape, CUDANet::DType dtype, CUDANet::Backend* backend)
: backend(backend),
in_shape(in_shape),
out_shape(out_shape) {
@@ -18,9 +21,11 @@ Dense::Dense(CUDANet::Shape in_shape, CUDANet::Shape out_shape, CUDANet::Backend
throw InvalidShapeException("output", 1, out_shape.size());
}
weights = CUDANet::Tensor(Shape{out_shape[0], in_shape[0]}, CUDANet::DType::FLOAT32, backend);
biases = CUDANet::Tensor(Shape{out_shape[0]}, CUDANet::DType::FLOAT32, backend);
output = CUDANet::Tensor(Shape{out_shape[0]}, CUDANet::DType::FLOAT32, backend);
this->dtype = dtype;
weights = CUDANet::Tensor(Shape{out_shape[0], in_shape[0]}, dtype, backend);
biases = CUDANet::Tensor(Shape{out_shape[0]}, dtype, backend);
output = CUDANet::Tensor(Shape{out_shape[0]}, dtype, backend);
weights.zero();
biases.zero();
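
The same delegation applies here, so both overloads remain usable side by side:

// dtype resolved from the backend default:
CUDANet::Layers::Dense fc1(CUDANet::Shape{512}, CUDANet::Shape{10}, &backend);

// dtype pinned explicitly:
CUDANet::Layers::Dense fc2(
    CUDANet::Shape{512}, CUDANet::Shape{10}, CUDANet::DType::FLOAT32, &backend
);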

View File

@@ -10,6 +10,16 @@ MaxPool2d::MaxPool2d(
CUDANet::Shape stride_shape,
CUDANet::Shape padding_shape,
CUDANet::Backend* backend
)
: MaxPool2d(input_shape, pool_shape, stride_shape, padding_shape, backend->get_default_dtype(), backend) {}
MaxPool2d::MaxPool2d(
CUDANet::Shape input_shape,
CUDANet::Shape pool_shape,
CUDANet::Shape stride_shape,
CUDANet::Shape padding_shape,
CUDANet::DType dtype,
CUDANet::Backend* backend
)
: in_shape(input_shape),
pool_shape(pool_shape),
@@ -32,6 +42,8 @@ MaxPool2d::MaxPool2d(
throw InvalidShapeException("padding", 2, padding_shape.size());
}
this->dtype = dtype;
out_shape = {
(in_shape[0] + 2 * padding_shape[0] - pool_shape[0]) / stride_shape[0] +
1,
@@ -42,7 +54,7 @@ MaxPool2d::MaxPool2d(
output = CUDANet::Tensor(
Shape{out_shape[0] * out_shape[1] * out_shape[2]},
CUDANet::DType::FLOAT32, backend
dtype, backend
);
}

View File

@@ -1,20 +1,27 @@
#include <stdexcept>
#include "tensor.hpp"
#include <stdexcept>
using namespace CUDANet;
Tensor::Tensor(Shape shape, CUDANet::Backend* backend)
: Tensor(shape, backend->get_default_dtype(), backend) {}
Tensor::Tensor(Shape shape, DType dtype, Backend* backend)
: shape(shape), dtype(dtype), backend(backend), d_ptr(nullptr) {
if (shape.empty()) {
throw std::runtime_error("Tensor shape cannot be empty");
}
// Check if backend supports DType
if (!backend->supports_dtype(dtype)) {
throw std::runtime_error("Unsupported DType");
}
// Count total elements
size_t count = 1;
for (const auto& dim : shape) {
count *= dim;
for (size_t i = 0; i < shape.size(); ++i) {
count *= shape[i];
}
total_elms = count;
@@ -39,9 +46,8 @@ Tensor::Tensor(Tensor&& other) noexcept
total_elms(other.total_elms),
total_size(other.total_size),
backend(other.backend),
d_ptr(other.d_ptr)
{
other.d_ptr = nullptr;
d_ptr(other.d_ptr) {
other.d_ptr = nullptr;
other.backend = nullptr;
}
@@ -51,17 +57,17 @@ Tensor& Tensor::operator=(Tensor&& other) noexcept {
if (d_ptr != nullptr && backend != nullptr) {
backend->deallocate(d_ptr);
}
// Steal other's resources
shape = std::move(other.shape);
dtype = other.dtype;
shape = std::move(other.shape);
dtype = other.dtype;
total_elms = other.total_elms;
total_size = other.total_size;
backend = other.backend;
d_ptr = other.d_ptr;
backend = other.backend;
d_ptr = other.d_ptr;
// Leave other in valid but empty state
other.d_ptr = nullptr;
other.d_ptr = nullptr;
other.backend = nullptr;
}
return *this;