Compare commits


3 Commits

SHA1        Message                                      Date
13d3d38b68  Add dtype parameter to layer constructors    2025-11-26 00:19:33 +01:00
84153ac49c  Add default dtype to backend                 2025-11-25 23:42:19 +01:00
ad079560ff  Update CMakeLists.txt                        2025-11-25 19:08:55 +01:00
24 changed files with 243 additions and 74 deletions

View File

@@ -23,8 +23,8 @@ endif()
 file(GLOB_RECURSE CPU_SOURCES
-    src/*.cpp
     src/layers/*.cpp
+    src/model/*.cpp
 )
 set(LIBRARY_SOURCES ${CPU_SOURCES})
@@ -32,10 +32,7 @@ set(LIBRARY_SOURCES ${CPU_SOURCES})
 if(USE_CUDA)
     file(GLOB_RECURSE CUDA_SOURCES
         src/backends/cuda/*.cu
-        src/backends/cuda/utils/*.cu
         src/backends/cuda/kernels/*.cu
-        src/backends/cuda/layers/*.cu
-        src/layers/*.cu # To be removed
     )
     set(LIBRARY_SOURCES ${LIBRARY_SOURCES} ${CUDA_SOURCES})
 endif()
@@ -52,11 +49,6 @@ endif()
 # Set include directories for the library
 target_include_directories(${PROJECT_NAME} PUBLIC
     ${CMAKE_CURRENT_SOURCE_DIR}/include
-    ${CMAKE_CURRENT_SOURCE_DIR}/include/utils
-    ${CMAKE_CURRENT_SOURCE_DIR}/include/kernels
-    ${CMAKE_CURRENT_SOURCE_DIR}/include/layers
-    ${CMAKE_CURRENT_SOURCE_DIR}/include/model
-    ${CMAKE_CURRENT_SOURCE_DIR}/src
 )
 set_property(TARGET ${PROJECT_NAME} PROPERTY CXX_STANDARD 20)

View File

@@ -1,8 +1,10 @@
 #pragma once
 #include <memory>
+#include <optional>
 #include "shape.hpp"
+#include "tensor.hpp"

 namespace CUDANet {
@@ -22,7 +24,14 @@ class BackendFactory {
 };

 class Backend {
+   protected:
+    std::optional<DType> default_dtype;
+
    public:
+    virtual bool supports_dtype(DType dtype) const = 0;
+    virtual void set_default_dtype(DType dtype) = 0;
+    virtual DType get_default_dtype() const = 0;
+
     // Memory management
     virtual void* allocate(size_t bytes) = 0;
     virtual void deallocate(void* ptr) = 0;
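
The new Backend surface keeps the default dtype in a std::optional until one is set, and lets callers query support before committing to a type. A minimal usage sketch against the CUDA backend from this changeset (BackendConfig is assumed to be an aggregate with a public device_id field, as its use in cuda.cu suggests; the rest is taken from the diff):

#include "backend/cuda/cuda.cuh"

int main() {
    CUDANet::BackendConfig config{};
    config.device_id = 0;  // assumption: public field, per its use in cuda.cu

    CUDANet::Backends::CUDA backend(config);

    // The CUDA constructor seeds supported_dtypes with FLOAT32 only,
    // so this call succeeds; any unsupported dtype would throw
    // std::runtime_error("Unsupported dtype").
    if (backend.supports_dtype(CUDANet::DType::FLOAT32)) {
        backend.set_default_dtype(CUDANet::DType::FLOAT32);
    }

    CUDANet::DType dtype = backend.get_default_dtype();  // DType::FLOAT32
    return 0;
}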

View File

@@ -1,6 +1,7 @@
 #pragma once
 #include <cstdio>
+#include <set>
 #include "backend.hpp"
 #include "tensor.hpp"
@@ -29,9 +30,14 @@ namespace CUDANet::Backends {
 class CUDA : public Backend {
    private:
     int device_id;
+    std::set<DType> supported_dtypes;

    public:
     CUDA(const BackendConfig& config);

+    bool supports_dtype(DType dtype) const override;
+    void set_default_dtype(DType dtype) override;
+    DType get_default_dtype() const override;
+
     static bool is_cuda_available();
     void initialize();

View File

@@ -16,6 +16,8 @@ namespace CUDANet {
  *
  */
 class Layer {
+   protected:
+    CUDANet::DType dtype;
+
    public:
     virtual ~Layer(){};
@@ -39,4 +41,4 @@ class Layer {
     virtual size_t get_biases_size() = 0;
 };
-} // namespace CUDANet::Layers
+} // namespace CUDANet

View File

@@ -20,12 +20,13 @@ enum ActivationType { SIGMOID, RELU, SOFTMAX, NONE };
  * @brief Utility class that performs activation
  *
  */
-class Activation : public Layer {
+class Activation : public CUDANet::Layer {
    public:
     Activation() = default;
     Activation(ActivationType activation, const CUDANet::Shape &shape, CUDANet::Backend* backend);
+    Activation(ActivationType activation, const CUDANet::Shape &shape, CUDANet::DType dtype, CUDANet::Backend* backend);

     ~Activation() = default;
@@ -50,7 +51,7 @@ class Activation : public Layer {
    private:
     CUDANet::Backend* backend;
-    ActivationType activationType;
+    ActivationType activation_type;
     CUDANet::Shape shape;
     CUDANet::Tensor softmax_sum;

View File

@@ -8,6 +8,7 @@ namespace CUDANet::Layers {
 class Add {
    public:
     Add(CUDANet::Shape a_shape, CUDANet::Shape b_shape, CUDANet::Backend* backend);
+    Add(CUDANet::Shape a_shape, CUDANet::Shape b_shape, CUDANet::DType dtype, CUDANet::Backend* backend);

     ~Add();
@@ -19,6 +20,8 @@ class Add {
     CUDANet::Tensor output;
     CUDANet::Backend *backend;
+
+    CUDANet::DType dtype;
 };
 } // namespace CUDANet::Layers

View File

@@ -4,7 +4,7 @@
 namespace CUDANet::Layers {

-class AvgPool2d : public Layer {
+class AvgPool2d : public CUDANet::Layer {
    public:
     AvgPool2d(
         CUDANet::Shape input_shape,
@@ -13,6 +13,14 @@ class AvgPool2d : public Layer {
         CUDANet::Shape padding_shape,
         CUDANet::Backend *backend
     );
+    AvgPool2d(
+        CUDANet::Shape input_shape,
+        CUDANet::Shape pool_shape,
+        CUDANet::Shape stride_shape,
+        CUDANet::Shape padding_shape,
+        CUDANet::DType dtype,
+        CUDANet::Backend *backend
+    );

     ~AvgPool2d();
@@ -50,6 +58,7 @@ class AvgPool2d : public Layer {
 class AdaptiveAvgPool2d : public AvgPool2d {
    public:
     AdaptiveAvgPool2d(CUDANet::Shape input_shape, CUDANet::Shape output_shape, CUDANet::Backend *backend);
+    AdaptiveAvgPool2d(CUDANet::Shape input_shape, CUDANet::Shape output_shape, CUDANet::DType dtype, CUDANet::Backend *backend);
 };
 } // namespace CUDANet::Layers
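
Every layer in this changeset gains the same overload pair: the legacy constructor now delegates to a new one taking an explicit CUDANet::DType, filling in backend->get_default_dtype() when the caller omits it. A hedged sketch of what that means at a call site (the shapes and the backend object are made up for illustration; Shape's initializer-list construction is inferred from its use elsewhere in this diff):

// Assumes a constructed backend, e.g. CUDANet::Backends::CUDA backend(config);
CUDANet::Layers::AvgPool2d pool_default(
    {32, 32, 8},  // input shape (illustrative H, W, C)
    {2, 2},       // pool
    {2, 2},       // stride
    {0, 0},       // padding
    &backend      // dtype falls back to backend->get_default_dtype()
);

CUDANet::Layers::AvgPool2d pool_pinned(
    {32, 32, 8}, {2, 2}, {2, 2}, {0, 0},
    CUDANet::DType::FLOAT32,  // dtype pinned explicitly for this layer
    &backend
);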

View File

@@ -4,9 +4,10 @@
 namespace CUDANet::Layers {

-class BatchNorm2d : public Layer {
+class BatchNorm2d : public CUDANet::Layer {
    public:
     BatchNorm2d(CUDANet::Shape input_shape, float epsilon, CUDANet::Backend *backend);
+    BatchNorm2d(CUDANet::Shape input_shape, float epsilon, CUDANet::DType dtype, CUDANet::Backend *backend);

     ~BatchNorm2d();

View File

@@ -12,6 +12,7 @@ class Concat {
    public:
     Concat(const CUDANet::Shape a_shape, const CUDANet::Shape b_shape, CUDANet::Backend *backend);
+    Concat(const CUDANet::Shape a_shape, const CUDANet::Shape b_shape, CUDANet::DType dtype, CUDANet::Backend *backend);

     ~Concat();
@@ -27,6 +28,8 @@ class Concat {
     CUDANet::Tensor output;
     CUDANet::Backend *backend;
+
+    CUDANet::DType dtype;
 };
 } // namespace CUDANet::Layers

View File

@@ -8,7 +8,7 @@ namespace CUDANet::Layers {
  * @brief 2D convolutional layer
  *
  */
-class Conv2d : public Layer {
+class Conv2d : public CUDANet::Layer {
    public:
     Conv2d(
         CUDANet::Shape input_shape,
@@ -17,6 +17,14 @@ class Conv2d : public Layer {
         CUDANet::Shape padding_shape,
         CUDANet::Backend* backend
     );
+    Conv2d(
+        CUDANet::Shape input_shape,
+        CUDANet::Shape kernel_shape,
+        CUDANet::Shape stride_shape,
+        CUDANet::Shape padding_shape,
+        CUDANet::DType dtype,
+        CUDANet::Backend* backend
+    );

     ~Conv2d();

View File

@@ -9,10 +9,11 @@ namespace CUDANet::Layers {
  * @brief Dense (fully connected) layer
  *
  */
-class Dense : public Layer {
+class Dense : public CUDANet::Layer {
    public:
     Dense(CUDANet::Shape input_shape, CUDANet::Shape output_shape, CUDANet::Backend *backend);
+    Dense(CUDANet::Shape input_shape, CUDANet::Shape output_shape, CUDANet::DType dtype, CUDANet::Backend *backend);

     ~Dense();

View File

@@ -4,7 +4,7 @@
 namespace CUDANet::Layers {

-class MaxPool2d : public Layer {
+class MaxPool2d : public CUDANet::Layer {
    public:
     MaxPool2d(
         CUDANet::Shape input_shape,
@@ -13,6 +13,14 @@ class MaxPool2d : public Layer {
         CUDANet::Shape padding_shape,
         CUDANet::Backend* backend
     );
+    MaxPool2d(
+        CUDANet::Shape input_shape,
+        CUDANet::Shape pool_shape,
+        CUDANet::Shape stride_shape,
+        CUDANet::Shape padding_shape,
+        CUDANet::DType dtype,
+        CUDANet::Backend* backend
+    );

     ~MaxPool2d();
     CUDANet::Tensor& forward(CUDANet::Tensor &input) override;

View File

@@ -66,6 +66,12 @@ struct Shape {
     __host__ bool operator!=(const Shape& other) const {
         return !(*this == other);
     }
+
+    __host__ __device__ bool empty() const {
+        return ndim == 0;
+    }
 };

 std::string format_shape(const Shape& shape) {

View File

@@ -16,11 +16,14 @@ enum class DType
     // INT32, // Not implemented yet
 };

+size_t dtype_size(DType dtype);
+
 class Tensor
 {
 public:
     Tensor() = default;
+    Tensor(Shape shape, CUDANet::Backend* backend);
     Tensor(Shape shape, DType dtype, CUDANet::Backend* backend);

     Tensor(Tensor&& other) noexcept;
@@ -30,6 +33,8 @@ public:
     ~Tensor();

+    DType get_dtype();
+
     size_t size() const;
     size_t numel() const;
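
Tensor picks up the same delegation: the new two-argument constructor forwards to the dtype-taking one with the backend default, while get_dtype() and the free dtype_size() helper let callers recover the element size. A small sketch, assuming a constructed backend whose default dtype is FLOAT32 (4 bytes per element):

// Both tensors end up FLOAT32 here; the first relies on the new
// delegating constructor, the second spells the dtype out.
CUDANet::Tensor a({16, 16}, &backend);
CUDANet::Tensor b({16, 16}, CUDANet::DType::FLOAT32, &backend);

// 16 * 16 = 256 elements; at dtype_size(FLOAT32) == 4 bytes each,
// the buffer works out to 1024 bytes.
size_t bytes = a.numel() * CUDANet::dtype_size(a.get_dtype());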

View File

@@ -5,12 +5,15 @@
 #include <format>

 #include "backend/cuda/cuda.cuh"
+#include "tensor.hpp"

 using namespace CUDANet::Backends;

 CUDA::CUDA(const BackendConfig& config) {
     device_id = config.device_id < 0 ? 0 : config.device_id;
+    supported_dtypes = {DType::FLOAT32};
+    default_dtype = DType::FLOAT32;
     initialize();
 }
@@ -41,6 +44,28 @@ void CUDA::initialize() {
     std::printf("Using CUDA device %d: %s\n", device_id, deviceProp.name);
 }

+bool CUDA::supports_dtype(DType dtype) const {
+    return supported_dtypes.contains(dtype);
+}
+
+void CUDA::set_default_dtype(DType dtype) {
+    if (!supported_dtypes.contains(dtype)) {
+        throw std::runtime_error("Unsupported dtype");
+    }
+    default_dtype = dtype;
+}
+
+CUDANet::DType CUDA::get_default_dtype() const {
+    if (default_dtype) {
+        return default_dtype.value();
+    }
+    const_cast<CUDA*>(this)->default_dtype = DType::FLOAT32;
+    return DType::FLOAT32;
+}
+
 void* CUDA::allocate(size_t bytes) {
     void* d_ptr = nullptr;
     CUDA_CHECK(cudaMalloc(&d_ptr, bytes));

View File

@@ -1,14 +1,30 @@
-#include "activation.hpp"
 #include <format>
 #include <stdexcept>
 #include <vector>

+#include "activation.hpp"
 #include "tensor.hpp"

 using namespace CUDANet::Layers;

-Activation::Activation(ActivationType activation, const CUDANet::Shape &shape, CUDANet::Backend* backend)
-    : backend(backend), activationType(activation), shape(shape) {
+Activation::Activation(
+    ActivationType activation,
+    const CUDANet::Shape& shape,
+    CUDANet::Backend* backend
+)
+    : Activation(activation, shape, backend->get_default_dtype(), backend) {}
+
+Activation::Activation(
+    ActivationType activation,
+    const CUDANet::Shape& shape,
+    CUDANet::DType dtype,
+    CUDANet::Backend* backend
+)
+    : activation_type(activation),
+      shape(shape),
+      backend(backend) {
+    this->dtype = dtype;
+
     if (shape.size() != 1) {
         throw InvalidShapeException("input", 1, shape.size());
@@ -16,15 +32,16 @@ Activation::Activation(ActivationType activation, const CUDANet::Shape &shape, C
     auto length = shape[0];

-    if (activationType == SOFTMAX) {
-        softmax_sum = CUDANet::Tensor({static_cast<size_t>(length)}, CUDANet::DType::FLOAT32, backend);
-        tensor_max = CUDANet::Tensor({static_cast<size_t>(length)}, CUDANet::DType::FLOAT32, backend);
+    if (activation_type == SOFTMAX) {
+        softmax_sum =
+            CUDANet::Tensor({static_cast<size_t>(length)}, dtype, backend);
+        tensor_max =
+            CUDANet::Tensor({static_cast<size_t>(length)}, dtype, backend);
     }
 }

 CUDANet::Tensor& Activation::forward(CUDANet::Tensor& input) {
-    switch (activationType)
-    {
+    switch (activation_type) {
         case ActivationType::SIGMOID:
             backend->sigmoid(input);
             break;

View File

@@ -3,7 +3,11 @@
 using namespace CUDANet::Layers;

-Add::Add(CUDANet::Shape a_shape, CUDANet::Shape b_shape, CUDANet::Backend* backend) : backend(backend) {
+Add::Add(CUDANet::Shape a_shape, CUDANet::Shape b_shape, CUDANet::Backend* backend)
+    : Add(a_shape, b_shape, backend->get_default_dtype(), backend) {}
+
+Add::Add(CUDANet::Shape a_shape, CUDANet::Shape b_shape, CUDANet::DType dtype, CUDANet::Backend* backend)
+    : backend(backend), dtype(dtype) {
     if (a_shape != b_shape) {
         throw InvalidShapeException(
             "Add requires matching dimensions", a_shape, b_shape
@@ -11,7 +15,7 @@ Add::Add(CUDANet::Shape a_shape, CUDANet::Shape b_shape, CUDANet::Backend* backe
     }

     out_shape = a_shape;
-    output = CUDANet::Tensor(out_shape, CUDANet::DType::FLOAT32, backend);
+    output = CUDANet::Tensor(out_shape, dtype, backend);
 }

 Add::~Add() {}

View File

@@ -11,6 +11,16 @@ AvgPool2d::AvgPool2d(
     CUDANet::Shape stride_shape,
     CUDANet::Shape padding_shape,
     CUDANet::Backend* backend
+)
+    : AvgPool2d(input_shape, pool_shape, stride_shape, padding_shape, backend->get_default_dtype(), backend) {}
+
+AvgPool2d::AvgPool2d(
+    CUDANet::Shape input_shape,
+    CUDANet::Shape pool_shape,
+    CUDANet::Shape stride_shape,
+    CUDANet::Shape padding_shape,
+    CUDANet::DType dtype,
+    CUDANet::Backend* backend
 )
     : in_shape(input_shape),
       pool_shape(pool_shape),
@@ -33,6 +43,8 @@ AvgPool2d::AvgPool2d(
         throw InvalidShapeException("padding", 2, padding_shape.size());
     }

+    this->dtype = dtype;
+
     out_shape = {
         (in_shape[0] + 2 * padding_shape[0] - pool_shape[0]) / stride_shape[0] +
             1,
@@ -43,7 +55,7 @@ AvgPool2d::AvgPool2d(
     output = CUDANet::Tensor(
         Shape{out_shape[0] * out_shape[1] * out_shape[2]},
-        CUDANet::DType::FLOAT32, backend
+        dtype, backend
     );
 }
@@ -96,6 +108,14 @@ AdaptiveAvgPool2d::AdaptiveAvgPool2d(
     CUDANet::Shape input_shape,
     CUDANet::Shape output_shape,
     CUDANet::Backend *backend
+)
+    : AdaptiveAvgPool2d(input_shape, output_shape, backend->get_default_dtype(), backend) {}
+
+AdaptiveAvgPool2d::AdaptiveAvgPool2d(
+    CUDANet::Shape input_shape,
+    CUDANet::Shape output_shape,
+    CUDANet::DType dtype,
+    CUDANet::Backend *backend
 )
     : AvgPool2d(
           input_shape,
@@ -114,12 +134,13 @@ AdaptiveAvgPool2d::AdaptiveAvgPool2d(
           (input_shape[0] - (output_shape[0] - 1) * (input_shape[0] / output_shape[0]) - 1) / 2,
           (input_shape[1] - (output_shape[1] - 1) * (input_shape[1] / output_shape[1]) - 1) / 2
       },
+      dtype,
       backend
   ) {
     out_shape = output_shape;
     output = CUDANet::Tensor(
         Shape{out_shape[0] * out_shape[1] * out_shape[2]},
-        CUDANet::DType::FLOAT32, backend
+        dtype, backend
     );
 }

View File

@@ -12,6 +12,14 @@ BatchNorm2d::BatchNorm2d(
     CUDANet::Shape input_shape,
     float eps,
     CUDANet::Backend *backend
+)
+    : BatchNorm2d(input_shape, eps, backend->get_default_dtype(), backend) {}
+
+BatchNorm2d::BatchNorm2d(
+    CUDANet::Shape input_shape,
+    float eps,
+    CUDANet::DType dtype,
+    CUDANet::Backend *backend
 )
     : in_shape(input_shape), backend(backend) {
@@ -19,22 +27,24 @@ BatchNorm2d::BatchNorm2d(
         throw InvalidShapeException("input", 3, in_shape.size());
     }

-    epsilon = CUDANet::Tensor({1}, CUDANet::DType::FLOAT32, backend);
+    this->dtype = dtype;
+
+    epsilon = CUDANet::Tensor({1}, dtype, backend);
     epsilon.set_data<float>(&eps);

-    running_mean = CUDANet::Tensor({in_shape[2]}, CUDANet::DType::FLOAT32, backend);
+    running_mean = CUDANet::Tensor({in_shape[2]}, dtype, backend);
     running_mean.zero();
-    running_var = CUDANet::Tensor({in_shape[2]}, CUDANet::DType::FLOAT32, backend);
+    running_var = CUDANet::Tensor({in_shape[2]}, dtype, backend);
     running_var.fill(1);

-    weights = CUDANet::Tensor({in_shape[2]}, CUDANet::DType::FLOAT32, backend);
+    weights = CUDANet::Tensor({in_shape[2]}, dtype, backend);
     weights.fill(1);
-    biases = CUDANet::Tensor({in_shape[2]}, CUDANet::DType::FLOAT32, backend);
+    biases = CUDANet::Tensor({in_shape[2]}, dtype, backend);
     biases.zero();

-    output = CUDANet::Tensor(in_shape, CUDANet::DType::FLOAT32, backend);
+    output = CUDANet::Tensor(in_shape, dtype, backend);
 }

 BatchNorm2d::~BatchNorm2d() {}

View File

@@ -3,7 +3,10 @@
 using namespace CUDANet::Layers;

 Concat::Concat(const CUDANet::Shape a_shape, const CUDANet::Shape b_shape, CUDANet::Backend *backend)
-    : a_shape(a_shape), b_shape(b_shape), backend(backend) {
+    : Concat(a_shape, b_shape, backend->get_default_dtype(), backend) {}
+
+Concat::Concat(const CUDANet::Shape a_shape, const CUDANet::Shape b_shape, CUDANet::DType dtype, CUDANet::Backend *backend)
+    : a_shape(a_shape), b_shape(b_shape), backend(backend), dtype(dtype) {
     if (a_shape[0] != b_shape[0] || a_shape[1] != b_shape[1]) {
         throw InvalidShapeException(
             "Concat requires matching height and width dimensions", a_shape,
@@ -12,7 +15,7 @@ Concat::Concat(const CUDANet::Shape a_shape, const CUDANet::Shape b_shape, CUDAN
     }

     out_shape = {a_shape[0], a_shape[1], a_shape[2] + b_shape[2]};
-    output = CUDANet::Tensor(out_shape, CUDANet::DType::FLOAT32, backend);
+    output = CUDANet::Tensor(out_shape, dtype, backend);
 }

 Concat::~Concat() {}

View File

@@ -14,6 +14,16 @@ Conv2d::Conv2d(
     CUDANet::Shape stride_shape,
     CUDANet::Shape padding_shape,
     CUDANet::Backend* backend
+)
+    : Conv2d(input_shape, kernel_shape, stride_shape, padding_shape, backend->get_default_dtype(), backend) {}
+
+Conv2d::Conv2d(
+    CUDANet::Shape input_shape,
+    CUDANet::Shape kernel_shape,
+    CUDANet::Shape stride_shape,
+    CUDANet::Shape padding_shape,
+    CUDANet::DType dtype,
+    CUDANet::Backend* backend
 )
     : in_shape(input_shape),
       kernel_shape(kernel_shape),
@@ -36,6 +46,8 @@ Conv2d::Conv2d(
         throw InvalidShapeException("padding", 3, padding_shape.size());
     }

+    this->dtype = dtype;
+
     out_shape = {
         (in_shape[0] - kernel_shape[0] + 2 * padding_shape[0]) /
                 stride_shape[0] +
@@ -48,17 +60,17 @@ Conv2d::Conv2d(
     output = CUDANet::Tensor(
         Shape{out_shape[0], out_shape[1], out_shape[2]},
-        CUDANet::DType::FLOAT32, backend
+        dtype, backend
     );

     weights = CUDANet::Tensor(
         Shape{
             kernel_shape[0], kernel_shape[1], kernel_shape[2], in_shape[2]
         },
-        CUDANet::DType::FLOAT32, backend
+        dtype, backend
     );

     biases = CUDANet::Tensor(
-        Shape{kernel_shape[2]}, CUDANet::DType::FLOAT32, backend
+        Shape{kernel_shape[2]}, dtype, backend
     );

     weights.zero();

View File

@@ -6,6 +6,9 @@
 using namespace CUDANet::Layers;

 Dense::Dense(CUDANet::Shape in_shape, CUDANet::Shape out_shape, CUDANet::Backend* backend)
+    : Dense(in_shape, out_shape, backend->get_default_dtype(), backend) {}
+
+Dense::Dense(CUDANet::Shape in_shape, CUDANet::Shape out_shape, CUDANet::DType dtype, CUDANet::Backend* backend)
     : backend(backend),
       in_shape(in_shape),
       out_shape(out_shape) {
@@ -18,9 +21,11 @@ Dense::Dense(CUDANet::Shape in_shape, CUDANet::Shape out_shape, CUDANet::Backend
         throw InvalidShapeException("output", 1, out_shape.size());
     }

-    weights = CUDANet::Tensor(Shape{out_shape[0], in_shape[0]}, CUDANet::DType::FLOAT32, backend);
-    biases = CUDANet::Tensor(Shape{out_shape[0]}, CUDANet::DType::FLOAT32, backend);
-    output = CUDANet::Tensor(Shape{out_shape[0]}, CUDANet::DType::FLOAT32, backend);
+    this->dtype = dtype;
+
+    weights = CUDANet::Tensor(Shape{out_shape[0], in_shape[0]}, dtype, backend);
+    biases = CUDANet::Tensor(Shape{out_shape[0]}, dtype, backend);
+    output = CUDANet::Tensor(Shape{out_shape[0]}, dtype, backend);

     weights.zero();
     biases.zero();

View File

@@ -10,6 +10,16 @@ MaxPool2d::MaxPool2d(
     CUDANet::Shape stride_shape,
     CUDANet::Shape padding_shape,
     CUDANet::Backend* backend
+)
+    : MaxPool2d(input_shape, pool_shape, stride_shape, padding_shape, backend->get_default_dtype(), backend) {}
+
+MaxPool2d::MaxPool2d(
+    CUDANet::Shape input_shape,
+    CUDANet::Shape pool_shape,
+    CUDANet::Shape stride_shape,
+    CUDANet::Shape padding_shape,
+    CUDANet::DType dtype,
+    CUDANet::Backend* backend
 )
     : in_shape(input_shape),
       pool_shape(pool_shape),
@@ -32,6 +42,8 @@ MaxPool2d::MaxPool2d(
         throw InvalidShapeException("padding", 2, padding_shape.size());
     }

+    this->dtype = dtype;
+
     out_shape = {
         (in_shape[0] + 2 * padding_shape[0] - pool_shape[0]) / stride_shape[0] +
             1,
@@ -42,7 +54,7 @@ MaxPool2d::MaxPool2d(
     output = CUDANet::Tensor(
         Shape{out_shape[0] * out_shape[1] * out_shape[2]},
-        CUDANet::DType::FLOAT32, backend
+        dtype, backend
     );
 }

View File

@@ -1,20 +1,27 @@
-#include <stdexcept>
 #include "tensor.hpp"
+
+#include <stdexcept>

 using namespace CUDANet;

+Tensor::Tensor(Shape shape, CUDANet::Backend* backend)
+    : Tensor(shape, backend->get_default_dtype(), backend) {}
+
 Tensor::Tensor(Shape shape, DType dtype, Backend* backend)
     : shape(shape), dtype(dtype), backend(backend), d_ptr(nullptr) {
     if (shape.empty()) {
         throw std::runtime_error("Tensor shape cannot be empty");
     }

+    // Check if backend supports DType
+    if (!backend->supports_dtype(dtype)) {
+        throw std::runtime_error("Unsupported DType");
+    }
+
     // Count total elements
     size_t count = 1;
-    for (const auto& dim : shape) {
-        count *= dim;
+    for (size_t i = 0; i < shape.size(); ++i) {
+        count *= shape[i];
     }
     total_elms = count;
@@ -39,8 +46,7 @@ Tensor::Tensor(Tensor&& other) noexcept
       total_elms(other.total_elms),
       total_size(other.total_size),
       backend(other.backend),
-      d_ptr(other.d_ptr)
-{
+      d_ptr(other.d_ptr) {
     other.d_ptr = nullptr;
     other.backend = nullptr;
 }