mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-12-23 06:44:24 +00:00
Refactor CUDA kernels and tensor operations for type generality
This commit is contained in:
@@ -5,9 +5,10 @@
|
||||
|
||||
namespace CUDANet::Kernels {
|
||||
|
||||
template <typename T>
|
||||
__global__ void max_pool(
|
||||
const float* __restrict__ d_input,
|
||||
float* __restrict__ d_output,
|
||||
const T* __restrict__ d_input,
|
||||
T* __restrict__ d_output,
|
||||
const Shape input_shape,
|
||||
const Shape output_shape,
|
||||
const Shape pool_shape,
|
||||
@@ -15,9 +16,10 @@ __global__ void max_pool(
|
||||
const Shape padding_shape
|
||||
);
|
||||
|
||||
template <typename T>
|
||||
__global__ void avg_pool(
|
||||
const float* __restrict__ d_input,
|
||||
float* __restrict__ d_output,
|
||||
const T* __restrict__ d_input,
|
||||
T* __restrict__ d_output,
|
||||
const Shape input_shape,
|
||||
const Shape output_shape,
|
||||
const Shape pool_shape,
|
||||
|
||||
Reference in New Issue
Block a user