From 1102aef293e6e98cfa6c5159341089d065e2e8a2 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Sun, 23 Nov 2025 19:21:06 +0100 Subject: [PATCH] Implement custom Shape struct with __device__ support --- include/shape.hpp | 62 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 61 insertions(+), 1 deletion(-) diff --git a/include/shape.hpp b/include/shape.hpp index 5f0b1ae..e9db6ee 100644 --- a/include/shape.hpp +++ b/include/shape.hpp @@ -1,11 +1,71 @@ #pragma once +#ifndef __host__ +#define __host__ +#endif + +#ifndef __device__ +#define __device__ +#endif + #include #include namespace CUDANet { -typedef std::vector Shape; +struct Shape { + static constexpr size_t MAX_DIMS = 8; + + size_t dims[MAX_DIMS]; + size_t ndim; + + __host__ __device__ Shape() : ndim(0) { + for (int i = 0; i < MAX_DIMS; i++) dims[i] = 0; + } + + __host__ Shape(std::initializer_list list) : ndim(list.size()) { + if (ndim > MAX_DIMS) { + throw std::runtime_error("Too many dimensions"); + } + size_t i = 0; + for (auto val : list) { + dims[i++] = val; + } + for (; i < MAX_DIMS; i++) dims[i] = 0; + } + + __host__ Shape(const std::vector& vec) : ndim(vec.size()) { + if (ndim > MAX_DIMS) { + throw std::runtime_error("Too many dimensions"); + } + for (size_t i = 0; i < ndim; i++) { + dims[i] = vec[i]; + } + for (size_t i = ndim; i < MAX_DIMS; i++) dims[i] = 0; + } + + __host__ __device__ size_t operator[](size_t idx) const { + return dims[idx]; + } + + __host__ __device__ size_t& operator[](size_t idx) { + return dims[idx]; + } + + __host__ __device__ size_t size() const { return ndim; } + + __host__ bool operator==(const Shape& other) const { + if (ndim != other.ndim) return false; + for (size_t i = 0; i < ndim; i++) { + if (dims[i] != other.dims[i]) return false; + } + return true; + } + + __host__ bool operator!=(const Shape& other) const { + return !(*this == other); + } +}; std::string format_shape(const Shape& shape) { std::string result;