Refactor CUDA kernels and tensor operations for type generality

This commit is contained in:
2025-11-26 20:47:55 +01:00
parent 13d3d38b68
commit 9ff214d759
14 changed files with 818 additions and 297 deletions

View File

@@ -5,9 +5,10 @@
namespace CUDANet::Kernels {
template <typename T>
__global__ void max_pool(
const float* __restrict__ d_input,
float* __restrict__ d_output,
const T* __restrict__ d_input,
T* __restrict__ d_output,
const Shape input_shape,
const Shape output_shape,
const Shape pool_shape,
@@ -15,9 +16,10 @@ __global__ void max_pool(
const Shape padding_shape
);
template <typename T>
__global__ void avg_pool(
const float* __restrict__ d_input,
float* __restrict__ d_output,
const T* __restrict__ d_input,
T* __restrict__ d_output,
const Shape input_shape,
const Shape output_shape,
const Shape pool_shape,