mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-12-22 14:24:22 +00:00
Implement custom Shape struct with __device__ support
This commit is contained in:
@@ -1,11 +1,71 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#ifndef __host__
|
||||||
|
#define __host__
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifndef __device__
|
||||||
|
#define __device__
|
||||||
|
#endif
|
||||||
|
|
||||||
#include <format>
|
#include <format>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace CUDANet {
|
namespace CUDANet {
|
||||||
|
|
||||||
typedef std::vector<size_t> Shape;
|
struct Shape {
|
||||||
|
static constexpr size_t MAX_DIMS = 8;
|
||||||
|
|
||||||
|
size_t dims[MAX_DIMS];
|
||||||
|
size_t ndim;
|
||||||
|
|
||||||
|
__host__ __device__ Shape() : ndim(0) {
|
||||||
|
for (int i = 0; i < MAX_DIMS; i++) dims[i] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
__host__ Shape(std::initializer_list<size_t> list) : ndim(list.size()) {
|
||||||
|
if (ndim > MAX_DIMS) {
|
||||||
|
throw std::runtime_error("Too many dimensions");
|
||||||
|
}
|
||||||
|
size_t i = 0;
|
||||||
|
for (auto val : list) {
|
||||||
|
dims[i++] = val;
|
||||||
|
}
|
||||||
|
for (; i < MAX_DIMS; i++) dims[i] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
__host__ Shape(const std::vector<size_t>& vec) : ndim(vec.size()) {
|
||||||
|
if (ndim > MAX_DIMS) {
|
||||||
|
throw std::runtime_error("Too many dimensions");
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < ndim; i++) {
|
||||||
|
dims[i] = vec[i];
|
||||||
|
}
|
||||||
|
for (size_t i = ndim; i < MAX_DIMS; i++) dims[i] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
__host__ __device__ size_t operator[](size_t idx) const {
|
||||||
|
return dims[idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
__host__ __device__ size_t& operator[](size_t idx) {
|
||||||
|
return dims[idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
__host__ __device__ size_t size() const { return ndim; }
|
||||||
|
|
||||||
|
__host__ bool operator==(const Shape& other) const {
|
||||||
|
if (ndim != other.ndim) return false;
|
||||||
|
for (size_t i = 0; i < ndim; i++) {
|
||||||
|
if (dims[i] != other.dims[i]) return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
__host__ bool operator!=(const Shape& other) const {
|
||||||
|
return !(*this == other);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
std::string format_shape(const Shape& shape) {
|
std::string format_shape(const Shape& shape) {
|
||||||
std::string result;
|
std::string result;
|
||||||
|
|||||||
Reference in New Issue
Block a user