From 9ff214d75933ca1bd040ed50f57091417c384497 Mon Sep 17 00:00:00 2001
From: LordMathis
Date: Wed, 26 Nov 2025 20:47:55 +0100
Subject: [PATCH] Refactor CUDA kernels and tensor operations for type generality

---
 include/backend/cuda/cuda.cuh                  | 208 +++++++---
 .../cuda/kernels/activation_functions.cuh      |  25 +-
 include/backend/cuda/kernels/convolution.cuh   |   9 +-
 include/backend/cuda/kernels/matmul.cuh        | 181 +++------
 include/backend/cuda/kernels/pool.cuh          |  10 +-
 include/tensor.hpp                             |   2 +-
 .../cuda/kernels/activation_functions.cu       |  19 +-
 src/backends/cuda/kernels/convolution.cu       |  17 +-
 src/backends/cuda/kernels/matmul.cu            | 146 ++++++-
 src/backends/cuda/kernels/pool.cu              |  30 +-
 src/backends/cuda/layer_ops.cu                 | 373 +++++++++++++++---
 src/backends/cuda/tensor_ops.cu                |  90 ++++-
 src/layers/dense.cpp                           |   1 +
 src/tensor.cpp                                 |   4 +
 14 files changed, 818 insertions(+), 297 deletions(-)

diff --git a/include/backend/cuda/cuda.cuh b/include/backend/cuda/cuda.cuh
index db548bc..002d6b9 100644
--- a/include/backend/cuda/cuda.cuh
+++ b/include/backend/cuda/cuda.cuh
@@ -8,53 +8,60 @@
 
 #ifndef BLOCK_SIZE
 #define BLOCK_SIZE 128
-#endif // BLOCK_SIZE
+#endif  // BLOCK_SIZE
 
 /**
  * @brief CUDA error checking macro
- *
+ *
  */
-#define CUDA_CHECK(call) \
-do { \
-    cudaError_t result = call; \
-    if (result != cudaSuccess) { \
-        fprintf(stderr, "CUDA error at %s:%d code=%d(%s) \"%s\" \n", \
-            __FILE__, __LINE__, static_cast<int>(result), \
-            cudaGetErrorString(result), #call); \
-        exit(EXIT_FAILURE); \
-    } \
-} while (0)
+#define CUDA_CHECK(call) \
+    do { \
+        cudaError_t result = call; \
+        if (result != cudaSuccess) { \
+            fprintf( \
+                stderr, "CUDA error at %s:%d code=%d(%s) \"%s\" \n", __FILE__, \
+                __LINE__, static_cast<int>(result), \
+                cudaGetErrorString(result), #call \
+            ); \
+            exit(EXIT_FAILURE); \
+        } \
+    } while (0)
 
 namespace CUDANet::Backends {
 
+template <DType dtype>
+struct cuda_dtype_map;
+
+template <>
+struct cuda_dtype_map<DType::FLOAT32> {
+    using type = float;
+};
+
 class CUDA : public Backend {
-   private:
-    int device_id;
-    std::set<DType> supported_dtypes;
    public:
     CUDA(const BackendConfig& config);
 
-    bool supports_dtype(DType dtype) const override;
-    void set_default_dtype(DType dtype) override;
+    bool  supports_dtype(DType dtype) const override;
+    void  set_default_dtype(DType dtype) override;
     DType get_default_dtype() const override;
 
     static bool is_cuda_available();
-    void initialize();
+    void        initialize();
 
     // Memory management
     void* allocate(size_t bytes) override;
     void  deallocate(void* ptr) override;
 
-    // Tensor ops
+    // Tensor ops dispatchers
     void print(const CUDANet::Tensor& input) override;
     void zero(CUDANet::Tensor& input) override;
-    void fill(CUDANet::Tensor &input, int value) override;
+    void fill(CUDANet::Tensor& input, int value) override;
     void copy_to_device(CUDANet::Tensor& tensor, void* data, size_t size) override;
 
     void sum(const CUDANet::Tensor& input, CUDANet::Tensor& sum) override;
     void max(const CUDANet::Tensor& input, CUDANet::Tensor& max) override;
 
-    // Layer ops
+    // Layer ops dispatchers
     void relu(CUDANet::Tensor& tensor) override;
     void sigmoid(CUDANet::Tensor& tensor) override;
     void softmax(
@@ -67,7 +74,7 @@ class CUDA : public Backend {
         const CUDANet::Tensor& weights,
         const CUDANet::Tensor& biases,
         const CUDANet::Tensor& input,
-        CUDANet::Tensor& output,
+        CUDANet::Tensor&       output,
         const size_t input_size,
         const size_t output_size
     ) override;
@@ -76,43 +83,43 @@ class CUDA : public Backend {
         const CUDANet::Tensor& weights,
         const CUDANet::Tensor& biases,
         const CUDANet::Tensor& input,
-        CUDANet::Tensor& output,
-        const 
CUDANet::Shape in_shape, - const CUDANet::Shape padding_shape, - const CUDANet::Shape kernel_shape, - const CUDANet::Shape stride_shape, - const CUDANet::Shape out_shape + CUDANet::Tensor& output, + const CUDANet::Shape in_shape, + const CUDANet::Shape padding_shape, + const CUDANet::Shape kernel_shape, + const CUDANet::Shape stride_shape, + const CUDANet::Shape out_shape ) override; CUDANet::Tensor& max_pool2d( const CUDANet::Tensor& input, - CUDANet::Tensor& output, - CUDANet::Shape input_shape, - CUDANet::Shape pool_shape, - CUDANet::Shape stride_shape, - CUDANet::Shape padding_shape, - CUDANet::Shape output_shape + CUDANet::Tensor& output, + CUDANet::Shape input_shape, + CUDANet::Shape pool_shape, + CUDANet::Shape stride_shape, + CUDANet::Shape padding_shape, + CUDANet::Shape output_shape ) override; CUDANet::Tensor& avg_pool2d( const CUDANet::Tensor& input, - CUDANet::Tensor& output, - CUDANet::Shape input_shape, - CUDANet::Shape pool_shape, - CUDANet::Shape stride_shape, - CUDANet::Shape padding_shape, - CUDANet::Shape output_shape + CUDANet::Tensor& output, + CUDANet::Shape input_shape, + CUDANet::Shape pool_shape, + CUDANet::Shape stride_shape, + CUDANet::Shape padding_shape, + CUDANet::Shape output_shape ) override; CUDANet::Tensor& batch_norm( const CUDANet::Tensor& input, - CUDANet::Tensor& output, - CUDANet::Shape input_shape, - CUDANet::Tensor& weights, - CUDANet::Tensor& biases, - CUDANet::Tensor& running_mean, - CUDANet::Tensor& running_var, - CUDANet::Tensor& epsilon + CUDANet::Tensor& output, + CUDANet::Shape input_shape, + CUDANet::Tensor& weights, + CUDANet::Tensor& biases, + CUDANet::Tensor& running_mean, + CUDANet::Tensor& running_var, + CUDANet::Tensor& epsilon ) override; CUDANet::Tensor& concat( @@ -126,6 +133,111 @@ class CUDA : public Backend { CUDANet::Tensor& input_b, CUDANet::Tensor& output ) override; + + private: + int device_id; + std::set supported_dtypes; + + // Tensor ops template impls + template + void print_impl(const CUDANet::Tensor& input); + + template + void fill_impl(CUDANet::Tensor& input, int value); + + template + void copy_to_device_impl(CUDANet::Tensor& tensor, void* data, size_t size); + + template + void sum_impl(const CUDANet::Tensor& input, CUDANet::Tensor& sum); + + template + void max_impl(const CUDANet::Tensor& input, CUDANet::Tensor& max); + + // Layer ops template impls + template + void relu_impl(CUDANet::Tensor& tensor); + + template + void sigmoid_impl(CUDANet::Tensor& tensor); + + template + void softmax_impl( + CUDANet::Tensor& tensor, + CUDANet::Tensor& temp_max, + CUDANet::Tensor& temp_sum + ); + + template + CUDANet::Tensor& dense_impl( + const CUDANet::Tensor& weights, + const CUDANet::Tensor& biases, + const CUDANet::Tensor& input, + CUDANet::Tensor& output, + const size_t input_size, + const size_t output_size + ); + + template + CUDANet::Tensor& conv2d_impl( + const CUDANet::Tensor& weights, + const CUDANet::Tensor& biases, + const CUDANet::Tensor& input, + CUDANet::Tensor& output, + const CUDANet::Shape in_shape, + const CUDANet::Shape padding_shape, + const CUDANet::Shape kernel_shape, + const CUDANet::Shape stride_shape, + const CUDANet::Shape out_shape + ); + + template + CUDANet::Tensor& max_pool2d_impl( + const CUDANet::Tensor& input, + CUDANet::Tensor& output, + CUDANet::Shape input_shape, + CUDANet::Shape pool_shape, + CUDANet::Shape stride_shape, + CUDANet::Shape padding_shape, + CUDANet::Shape output_shape + ); + + template + CUDANet::Tensor& avg_pool2d_impl( + const CUDANet::Tensor& input, + CUDANet::Tensor& 
output, + CUDANet::Shape input_shape, + CUDANet::Shape pool_shape, + CUDANet::Shape stride_shape, + CUDANet::Shape padding_shape, + CUDANet::Shape output_shape + ); + + template + CUDANet::Tensor& batch_norm_impl( + const CUDANet::Tensor& input, + CUDANet::Tensor& output, + CUDANet::Shape input_shape, + CUDANet::Tensor& weights, + CUDANet::Tensor& biases, + CUDANet::Tensor& running_mean, + CUDANet::Tensor& running_var, + CUDANet::Tensor& epsilon + ); + + template + CUDANet::Tensor& concat_impl( + CUDANet::Tensor& input_a, + CUDANet::Tensor& input_b, + CUDANet::Tensor& output + ); + + template + CUDANet::Tensor& add_impl( + CUDANet::Tensor& input_a, + CUDANet::Tensor& input_b, + CUDANet::Tensor& output + ); }; -} // namespace CUDANet::Backend \ No newline at end of file +} // namespace CUDANet::Backends \ No newline at end of file diff --git a/include/backend/cuda/kernels/activation_functions.cuh b/include/backend/cuda/kernels/activation_functions.cuh index e5f2def..411fefb 100644 --- a/include/backend/cuda/kernels/activation_functions.cuh +++ b/include/backend/cuda/kernels/activation_functions.cuh @@ -4,29 +4,18 @@ namespace CUDANet::Kernels { -/** - * @brief Sigmoid activation function kernel - * - * @param src Pointer to the source array - * @param dst Pointer to the destination array - * @param len Length of the arrays - */ + +template __global__ void sigmoid( - const float* __restrict__ src, - float* __restrict__ dst, + const T* __restrict__ src, + T* __restrict__ dst, const unsigned int len ); -/** - * @brief Relu activation function kernel - * - * @param src Pointer to the source array - * @param dst Pointer to the destination array - * @param len Length of the arrays - */ +template __global__ void relu( - const float* __restrict__ src, - float* __restrict__ dst, + const T* __restrict__ src, + T* __restrict__ dst, const unsigned int len ); diff --git a/include/backend/cuda/kernels/convolution.cuh b/include/backend/cuda/kernels/convolution.cuh index 96d5ace..c2c6203 100644 --- a/include/backend/cuda/kernels/convolution.cuh +++ b/include/backend/cuda/kernels/convolution.cuh @@ -5,11 +5,12 @@ namespace CUDANet::Kernels { +template __global__ void convolution( - const float* __restrict__ d_input, - const float* __restrict__ d_kernel, - const float* __restrict__ d_bias, - float* __restrict__ d_output, + const T* __restrict__ d_input, + const T* __restrict__ d_kernel, + const T* __restrict__ d_bias, + T* __restrict__ d_output, const Shape input_shape, const Shape padding_shape, const Shape kernel_shape, diff --git a/include/backend/cuda/kernels/matmul.cuh b/include/backend/cuda/kernels/matmul.cuh index 55b5856..56d5ce2 100644 --- a/include/backend/cuda/kernels/matmul.cuh +++ b/include/backend/cuda/kernels/matmul.cuh @@ -4,188 +4,105 @@ namespace CUDANet::Kernels { -/** - * @brief Matrix vector multiplication kernel - * - * @param d_matrix Device pointer to matrix - * @param d_vector Device pointer to vector - * @param d_output Device pointer to output vector - * @param w Width of the matrix - * @param h Height of the matrix - */ +template __global__ void mat_vec_mul( - const float* __restrict__ d_matrix, - const float* __restrict__ d_vector, - float* __restrict__ d_output, + const T* __restrict__ d_matrix, + const T* __restrict__ d_vector, + T* __restrict__ d_output, const unsigned int w, const unsigned int h ); -/** - * @brief Vector vector addition kernel - * - * @param d_vector1 Device pointer to first vector - * @param d_vector2 Device pointer to second vector - * @param d_output 
Device pointer to output vector - * @param w Length of the vectors - */ +template __global__ void vec_vec_add( - const float* __restrict__ d_vector1, - const float* __restrict__ d_vector2, - float* __restrict__ d_output, + const T* __restrict__ d_vector1, + const T* __restrict__ d_vector2, + T* __restrict__ d_output, const unsigned int w ); -/** - * @brief Vector vector subtraction kernel - * - * @param d_vector1 - * @param d_vector2 - * @param d_output - * @param w - * @return __global__ - */ +template __global__ void vec_vec_sub( - const float* __restrict__ d_vector1, - const float* __restrict__ d_vector2, - float* __restrict__ d_output, + const T* __restrict__ d_vector1, + const T* __restrict__ d_vector2, + T* __restrict__ d_output, const unsigned int w ); +template __global__ void vec_vec_mul( - const float* __restrict__ d_vector1, - const float* __restrict__ d_vector2, - float* __restrict__ d_output, + const T* __restrict__ d_vector1, + const T* __restrict__ d_vector2, + T* __restrict__ d_output, const unsigned int w ); -/** - * @brief Sub scalar from each element of the vector - * - * @param d_vector - * @param d_scalar - * @param d_output - * @param w - * @return __global__ - */ +template __global__ void vec_scalar_sub( - const float* __restrict__ d_src, - float* __restrict__ d_out, - const float* __restrict__ d_scalar, + const T* __restrict__ d_src, + T* __restrict__ d_out, + const T* __restrict__ d_scalar, const unsigned int len ); -/** - * @brief Add scalar to each element of the vector - * - * @param d_src - * @param d_out - * @param d_scalar - * @param len - * @return __global__ - */ +template __global__ void vec_scalar_add( - const float* __restrict__ d_src, - float* __restrict__ d_out, - const float* __restrict__ d_scalar, + const T* __restrict__ d_src, + T* __restrict__ d_out, + const T* __restrict__ d_scalar, const unsigned int len ); -/** - * @brief Divide each element of the vector by a scalar - * - * @param src Pointer to the source array - * @param dst Pointer to the destination array - * @param len Length of the arrays - */ +template __global__ void vec_scalar_div( - const float* __restrict__ d_src, - float* __restrict__ d_out, - const float* __restrict__ d_scalar, + const T* __restrict__ d_src, + T* __restrict__ d_out, + const T* __restrict__ d_scalar, const unsigned int len ); -/** - * @brief Multiply each element of the vector by a scalar - * - * @param d_src - * @param d_out - * @param d_scalar - * @param len - * @return __global__ - */ +template __global__ void vec_scalar_mul( - const float* __restrict__ d_src, - float* __restrict__ d_out, - const float* __restrict__ d_scalar, + const T* __restrict__ d_src, + T* __restrict__ d_out, + const T* __restrict__ d_scalar, const unsigned int len ); -/** - * @brief Exponentiate each element of the vector - * - * @param src Pointer to the source array - * @param dst Pointer to the destination array - * @param len Length of the arrays - */ +template __global__ void vec_exp( - const float* __restrict__ src, - float* __restrict__ dst, + const T* __restrict__ src, + T* __restrict__ dst, const unsigned int len ); -/** - * @brief Compute the square root of each element of the vector - * - * @param src Device pointer to source vector - * @param dst Device pointer to destination vector - * @param len Length of the vector - */ +template __global__ void vec_sqrt( - const float* __restrict__ src, - float* __restrict__ dst, + const T* __restrict__ src, + T* __restrict__ dst, const unsigned int len ); -/** - * @brief Scales the vector 
by 1/sqrt(scale + epsilon) - * - * @param src Device pointer to source vector - * @param dst Device pointer to destination vector - * @param scale Scale - * @param epsilon Epsilon - * @param len Length of the vector - */ +template __global__ void vec_scale( - const float* __restrict__ src, - float* __restrict__ dst, - const float* __restrict__ scale, - const float* epsilon, + const T* __restrict__ src, + T* __restrict__ dst, + const T* __restrict__ scale, + const T* epsilon, const unsigned int len ); -/** - * @brief Max reduction kernel - * - * @param d_vector Device pointer to vector - * @param d_output Device pointer to output vector - */ +template __global__ void max_reduce( - const float* __restrict__ d_vector, - float* __restrict__ d_output, + const T* __restrict__ d_vector, + T* __restrict__ d_output, const unsigned int len ); -/** - * @brief - * - * @param d_vector Device pointer to vector - * @param d_output Device pointer to output vector - * @param len Length of the vector - */ +template __global__ void sum_reduce( - const float* __restrict__ d_vector, - float* __restrict__ d_output, + const T* __restrict__ d_vector, + T* __restrict__ d_output, const unsigned int len ); diff --git a/include/backend/cuda/kernels/pool.cuh b/include/backend/cuda/kernels/pool.cuh index ed92256..e6787f5 100644 --- a/include/backend/cuda/kernels/pool.cuh +++ b/include/backend/cuda/kernels/pool.cuh @@ -5,9 +5,10 @@ namespace CUDANet::Kernels { +template __global__ void max_pool( - const float* __restrict__ d_input, - float* __restrict__ d_output, + const T* __restrict__ d_input, + T* __restrict__ d_output, const Shape input_shape, const Shape output_shape, const Shape pool_shape, @@ -15,9 +16,10 @@ __global__ void max_pool( const Shape padding_shape ); +template __global__ void avg_pool( - const float* __restrict__ d_input, - float* __restrict__ d_output, + const T* __restrict__ d_input, + T* __restrict__ d_output, const Shape input_shape, const Shape output_shape, const Shape pool_shape, diff --git a/include/tensor.hpp b/include/tensor.hpp index fb22bec..0fb3e2a 100644 --- a/include/tensor.hpp +++ b/include/tensor.hpp @@ -33,7 +33,7 @@ public: ~Tensor(); - DType get_dtype(); + DType get_dtype() const; size_t size() const; size_t numel() const; diff --git a/src/backends/cuda/kernels/activation_functions.cu b/src/backends/cuda/kernels/activation_functions.cu index 444fd3a..e2049a9 100644 --- a/src/backends/cuda/kernels/activation_functions.cu +++ b/src/backends/cuda/kernels/activation_functions.cu @@ -2,10 +2,18 @@ using namespace CUDANet; -__global__ void Kernels::sigmoid( +template +__global__ void Kernels::sigmoid( const float* __restrict__ src, float* __restrict__ dst, const unsigned int len +); + +template +__global__ void Kernels::sigmoid( + const T* __restrict__ src, + T* __restrict__ dst, + const unsigned int len ) { int stride = gridDim.x * blockDim.x; int tid = blockDim.x * blockIdx.x + threadIdx.x; @@ -15,10 +23,17 @@ __global__ void Kernels::sigmoid( } } -__global__ void Kernels::relu( +template __global__ void Kernels::relu( const float* __restrict__ src, float* __restrict__ dst, const unsigned int len +); + +template +__global__ void Kernels::relu( + const T* __restrict__ src, + T* __restrict__ dst, + const unsigned int len ) { int stride = gridDim.x * blockDim.x; int tid = blockDim.x * blockIdx.x + threadIdx.x; diff --git a/src/backends/cuda/kernels/convolution.cu b/src/backends/cuda/kernels/convolution.cu index c815dfe..85abf98 100644 --- a/src/backends/cuda/kernels/convolution.cu +++ 
b/src/backends/cuda/kernels/convolution.cu @@ -4,7 +4,7 @@ using namespace CUDANet; -__global__ void Kernels::convolution( +template __global__ void Kernels::convolution( const float* __restrict__ d_input, const float* __restrict__ d_kernel, const float* __restrict__ d_bias, @@ -14,6 +14,19 @@ __global__ void Kernels::convolution( const Shape kernel_shape, const Shape stride_shape, const Shape output_shape +); + +template +__global__ void Kernels::convolution( + const T* __restrict__ d_input, + const T* __restrict__ d_kernel, + const T* __restrict__ d_bias, + T* __restrict__ d_output, + const Shape input_shape, + const Shape padding_shape, + const Shape kernel_shape, + const Shape stride_shape, + const Shape output_shape ) { int j = blockDim.x * blockIdx.x + threadIdx.x; int i = blockDim.y * blockIdx.y + threadIdx.y; @@ -23,7 +36,7 @@ __global__ void Kernels::convolution( return; } - float sum = 0.0f; + T sum = static_cast(0); // Iterate over kernel and input matrix for (int c = 0; c < input_shape[2]; c++) { diff --git a/src/backends/cuda/kernels/matmul.cu b/src/backends/cuda/kernels/matmul.cu index 23b039e..7e5f840 100644 --- a/src/backends/cuda/kernels/matmul.cu +++ b/src/backends/cuda/kernels/matmul.cu @@ -3,17 +3,26 @@ using namespace CUDANet; -__global__ void Kernels::mat_vec_mul( +template __global__ void Kernels::mat_vec_mul( const float* __restrict__ d_matrix, const float* __restrict__ d_vector, float* __restrict__ d_output, const unsigned int w, const unsigned int h +); + +template +__global__ void Kernels::mat_vec_mul( + const T* __restrict__ d_matrix, + const T* __restrict__ d_vector, + T* __restrict__ d_output, + const unsigned int w, + const unsigned int h ) { int tid = blockDim.x * blockIdx.x + threadIdx.x; if (tid < h) { - float temp = 0.0f; + T temp = static_cast(0); for (unsigned int j = 0; j < w; j++) { temp += d_matrix[tid * w + j] * d_vector[j]; @@ -23,11 +32,19 @@ __global__ void Kernels::mat_vec_mul( } } -__global__ void Kernels::vec_vec_add( +template __global__ void Kernels::vec_vec_add( const float* __restrict__ d_vector1, const float* __restrict__ d_vector2, float* __restrict__ d_output, const unsigned int w +); + +template +__global__ void Kernels::vec_vec_add( + const T* __restrict__ d_vector1, + const T* __restrict__ d_vector2, + T* __restrict__ d_output, + const unsigned int w ) { int tid = blockDim.x * blockIdx.x + threadIdx.x; if (tid >= w) { @@ -36,11 +53,19 @@ __global__ void Kernels::vec_vec_add( d_output[tid] = d_vector1[tid] + d_vector2[tid]; } -__global__ void Kernels::vec_vec_sub( +template __global__ void Kernels::vec_vec_sub( const float* __restrict__ d_vector1, const float* __restrict__ d_vector2, float* __restrict__ d_output, const unsigned int w +); + +template +__global__ void Kernels::vec_vec_sub( + const T* __restrict__ d_vector1, + const T* __restrict__ d_vector2, + T* __restrict__ d_output, + const unsigned int w ) { int tid = blockDim.x * blockIdx.x + threadIdx.x; if (tid >= w) { @@ -49,11 +74,19 @@ __global__ void Kernels::vec_vec_sub( d_output[tid] = d_vector1[tid] - d_vector2[tid]; } -__global__ void Kernels::vec_vec_mul( +template __global__ void Kernels::vec_vec_mul( const float* __restrict__ d_vector1, const float* __restrict__ d_vector2, float* __restrict__ d_output, const unsigned int w +); + +template +__global__ void Kernels::vec_vec_mul( + const T* __restrict__ d_vector1, + const T* __restrict__ d_vector2, + T* __restrict__ d_output, + const unsigned int w ) { int tid = blockDim.x * blockIdx.x + threadIdx.x; if (tid >= w) { @@ 
-62,11 +95,19 @@ __global__ void Kernels::vec_vec_mul( d_output[tid] = d_vector1[tid] * d_vector2[tid]; } -__global__ void Kernels::vec_scalar_sub( +template __global__ void Kernels::vec_scalar_sub( const float* __restrict__ d_src, float* __restrict__ d_out, const float* __restrict__ d_scalar, const unsigned int len +); + +template +__global__ void Kernels::vec_scalar_sub( + const T* __restrict__ d_src, + T* __restrict__ d_out, + const T* __restrict__ d_scalar, + const unsigned int len ) { int tid = blockDim.x * blockIdx.x + threadIdx.x; if (tid >= len) { @@ -75,11 +116,19 @@ __global__ void Kernels::vec_scalar_sub( d_out[tid] = d_src[tid] - *d_scalar; } -__global__ void Kernels::vec_scalar_add( +template __global__ void Kernels::vec_scalar_add( const float* __restrict__ d_src, float* __restrict__ d_out, const float* __restrict__ d_scalar, const unsigned int len +); + +template +__global__ void Kernels::vec_scalar_add( + const T* __restrict__ d_src, + T* __restrict__ d_out, + const T* __restrict__ d_scalar, + const unsigned int len ) { int tid = blockDim.x * blockIdx.x + threadIdx.x; if (tid >= len) { @@ -88,11 +137,19 @@ __global__ void Kernels::vec_scalar_add( d_out[tid] = d_src[tid] + *d_scalar; } -__global__ void Kernels::vec_scalar_div( +template __global__ void Kernels::vec_scalar_div( const float* __restrict__ d_src, float* __restrict__ d_out, const float* __restrict__ d_scalar, const unsigned int len +); + +template +__global__ void Kernels::vec_scalar_div( + const T* __restrict__ d_src, + T* __restrict__ d_out, + const T* __restrict__ d_scalar, + const unsigned int len ) { int tid = blockDim.x * blockIdx.x + threadIdx.x; if (tid >= len) { @@ -101,11 +158,19 @@ __global__ void Kernels::vec_scalar_div( d_out[tid] = d_src[tid] / *d_scalar; } -__global__ void Kernels::vec_scalar_mul( +template __global__ void Kernels::vec_scalar_mul( const float* __restrict__ d_src, float* __restrict__ d_out, const float* __restrict__ d_scalar, const unsigned int len +); + +template +__global__ void Kernels::vec_scalar_mul( + const T* __restrict__ d_src, + T* __restrict__ d_out, + const T* __restrict__ d_scalar, + const unsigned int len ) { int tid = blockDim.x * blockIdx.x + threadIdx.x; if (tid >= len) { @@ -114,64 +179,98 @@ __global__ void Kernels::vec_scalar_mul( d_out[tid] = d_src[tid] * *d_scalar; } -__global__ void Kernels::vec_exp( +template __global__ void Kernels::vec_exp( const float* __restrict__ src, float* __restrict__ dst, const unsigned int len +); + +template +__global__ void Kernels::vec_exp( + const T* __restrict__ src, + T* __restrict__ dst, + const unsigned int len ) { int stride = gridDim.x * blockDim.x; int tid = blockDim.x * blockIdx.x + threadIdx.x; for (int i = tid; i < len; i += stride) { + // TODO: separate implementation for __half dst[i] = expf(src[i]); } } -__global__ void Kernels::vec_sqrt( +template __global__ void Kernels::vec_sqrt( const float* __restrict__ src, float* __restrict__ dst, const unsigned int len +); + +template +__global__ void Kernels::vec_sqrt( + const T* __restrict__ src, + T* __restrict__ dst, + const unsigned int len ) { int stride = gridDim.x * blockDim.x; - int tid = blockDim.x * blockIdx.x + threadIdx.x; + int tid = blockDim.x * blockIdx.x + threadIdx.x; for (int i = tid; i < len; i += stride) { + // TODO: separate implementation for __half dst[i] = sqrtf(src[i]); } } -__global__ void Kernels::vec_scale( +template __global__ void Kernels::vec_scale( const float* __restrict__ src, float* __restrict__ dst, const float* __restrict__ scale, const 
float* epsilon, const unsigned int len +); + +template +__global__ void Kernels::vec_scale( + const T* __restrict__ src, + T* __restrict__ dst, + const T* __restrict__ scale, + const T* epsilon, + const unsigned int len ) { int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < len) { + // TODO: separate implementation for __half float inv_std = rsqrtf(*scale + *epsilon); dst[idx] = src[idx] * inv_std; } } -__global__ void Kernels::max_reduce( +template __global__ void Kernels::max_reduce( const float* __restrict__ d_vector, float* __restrict__ d_output, const unsigned int len +); + +template +__global__ void Kernels::max_reduce( + const T* __restrict__ d_vector, + T* __restrict__ d_output, + const unsigned int len ) { - __shared__ float shared_max[BLOCK_SIZE]; + __shared__ T shared_max[BLOCK_SIZE]; int i = blockIdx.x * blockDim.x + threadIdx.x; if (i < len) { shared_max[threadIdx.x] = d_vector[i]; } else { shared_max[threadIdx.x] = -INFINITY; - } + } __syncthreads(); for (int s = blockDim.x / 2; s > 0; s >>= 1) { if (threadIdx.x < s) { + // TODO: separate implementation for __half shared_max[threadIdx.x] = fmaxf(shared_max[threadIdx.x], shared_max[threadIdx.x + s]); } __syncthreads(); @@ -182,18 +281,25 @@ __global__ void Kernels::max_reduce( } } -__global__ void Kernels::sum_reduce( +template __global__ void Kernels::sum_reduce( const float* __restrict__ d_vector, float* __restrict__ d_output, const unsigned int len +); + +template +__global__ void Kernels::sum_reduce( + const T* __restrict__ d_vector, + T* __restrict__ d_output, + const unsigned int len ) { - __shared__ float partial_sum[BLOCK_SIZE]; + __shared__ T partial_sum[BLOCK_SIZE]; int i = blockIdx.x * blockDim.x + threadIdx.x; if (i < len) { partial_sum[threadIdx.x] = d_vector[i]; } else { - partial_sum[threadIdx.x] = 0.0f; + partial_sum[threadIdx.x] = static_cast(0); } __syncthreads(); @@ -208,4 +314,4 @@ __global__ void Kernels::sum_reduce( if (threadIdx.x == 0) { d_output[blockIdx.x] = partial_sum[0]; } -} \ No newline at end of file +} diff --git a/src/backends/cuda/kernels/pool.cu b/src/backends/cuda/kernels/pool.cu index 6e5c02a..227ef8f 100644 --- a/src/backends/cuda/kernels/pool.cu +++ b/src/backends/cuda/kernels/pool.cu @@ -3,7 +3,7 @@ using namespace CUDANet; -__global__ void Kernels::max_pool( +template __global__ void Kernels::max_pool( const float* __restrict__ d_input, float* __restrict__ d_output, const Shape input_shape, @@ -11,6 +11,17 @@ __global__ void Kernels::max_pool( const Shape pool_shape, const Shape stride_shape, const Shape padding_shape +); + +template +__global__ void Kernels::max_pool( + const T* __restrict__ d_input, + T* __restrict__ d_output, + const Shape input_shape, + const Shape output_shape, + const Shape pool_shape, + const Shape stride_shape, + const Shape padding_shape ) { int j = blockDim.x * blockIdx.x + threadIdx.x; int i = blockDim.y * blockIdx.y + threadIdx.y; @@ -20,7 +31,7 @@ __global__ void Kernels::max_pool( return; } - float max = 0.0f; + T max = static_cast(0); for (int k = 0; k < pool_shape[0]; k++) { for (int l = 0; l < pool_shape[1]; l++) { @@ -43,7 +54,7 @@ __global__ void Kernels::max_pool( max; } -__global__ void Kernels::avg_pool( +template __global__ void Kernels::avg_pool( const float* __restrict__ d_input, float* __restrict__ d_output, const Shape input_shape, @@ -51,6 +62,17 @@ __global__ void Kernels::avg_pool( const Shape pool_shape, const Shape stride_shape, const Shape padding_shape +); + +template +__global__ void Kernels::avg_pool( + const T* __restrict__ 
d_input, + T* __restrict__ d_output, + const Shape input_shape, + const Shape output_shape, + const Shape pool_shape, + const Shape stride_shape, + const Shape padding_shape ) { int j = blockDim.x * blockIdx.x + threadIdx.x; int i = blockDim.y * blockIdx.y + threadIdx.y; @@ -60,7 +82,7 @@ __global__ void Kernels::avg_pool( return; } - float sum = 0.0f; + T sum = static_cast(0); for (int k = 0; k < pool_shape[0]; k++) { for (int l = 0; l < pool_shape[1]; l++) { diff --git a/src/backends/cuda/layer_ops.cu b/src/backends/cuda/layer_ops.cu index 0c71033..812b042 100644 --- a/src/backends/cuda/layer_ops.cu +++ b/src/backends/cuda/layer_ops.cu @@ -7,24 +7,70 @@ using namespace CUDANet::Backends; void CUDA::relu(Tensor& tensor) { + switch (tensor.get_dtype()) { + case DType::FLOAT32: + relu_impl(tensor); + break; + + default: + throw std::runtime_error("Unsupported dtype"); + break; + } +} + +template void CUDA::relu_impl(Tensor& tensor); + +template +void CUDA::relu_impl(Tensor& tensor) { int gridSize = (tensor.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE; Kernels::relu<<>>( - tensor.data(), tensor.data(), tensor.numel() + tensor.data(), tensor.data(), tensor.numel() ); CUDA_CHECK(cudaGetLastError()); CUDA_CHECK(cudaDeviceSynchronize()); } -void CUDA::sigmoid(Tensor& tensor) { +void CUDA::sigmoid(CUDANet::Tensor& tensor) { + switch (tensor.get_dtype()) { + case DType::FLOAT32: + sigmoid_impl(tensor); + break; + + default: + throw std::runtime_error("Unsupported dtype"); + break; + } +} + +template void CUDA::sigmoid_impl(Tensor& tensor); + +template +void CUDA::sigmoid_impl(CUDANet::Tensor& tensor) { int gridSize = (tensor.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE; Kernels::sigmoid<<>>( - tensor.data(), tensor.data(), tensor.numel() + tensor.data(), tensor.data(), tensor.numel() ); CUDA_CHECK(cudaGetLastError()); CUDA_CHECK(cudaDeviceSynchronize()); } void CUDA::softmax(Tensor& tensor, Tensor& temp_max, Tensor& temp_sum) { + switch (tensor.get_dtype()) { + case DType::FLOAT32: + softmax_impl(tensor, temp_max, temp_sum); + break; + + default: + throw std::runtime_error("Unsupported dtype"); + break; + } +} + +template void +CUDA::softmax_impl(Tensor& tensor, Tensor& temp_max, Tensor& temp_sum); + +template +void CUDA::softmax_impl(Tensor& tensor, Tensor& temp_max, Tensor& temp_sum) { int gridSize = (tensor.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE; // Find max value @@ -32,14 +78,13 @@ void CUDA::softmax(Tensor& tensor, Tensor& temp_max, Tensor& temp_sum) { // Subtract max value to improve numerical stability Kernels::vec_scalar_sub<<>>( - tensor.data(), tensor.data(), temp_max.data(), - tensor.numel() + tensor.data(), tensor.data(), temp_max.data(), tensor.numel() ); CUDA_CHECK(cudaGetLastError()); // Compute exponentials Kernels::vec_exp<<>>( - tensor.data(), tensor.data(), tensor.numel() + tensor.data(), tensor.data(), tensor.numel() ); CUDA_CHECK(cudaGetLastError()); @@ -47,8 +92,7 @@ void CUDA::softmax(Tensor& tensor, Tensor& temp_max, Tensor& temp_sum) { sum(tensor, temp_sum); Kernels::vec_scalar_div<<>>( - tensor.data(), tensor.data(), temp_sum.data(), - tensor.numel() + tensor.data(), tensor.data(), temp_sum.data(), tensor.numel() ); CUDA_CHECK(cudaGetLastError()); CUDA_CHECK(cudaDeviceSynchronize()); @@ -61,20 +105,50 @@ CUDANet::Tensor& CUDA::dense( CUDANet::Tensor& output, const size_t input_size, const size_t output_size +) { + switch (input.get_dtype()) { + case DType::FLOAT32: + return dense_impl( + weights, biases, input, output, input_size, output_size + ); + break; + + default: + throw 
std::runtime_error("Unsupported dtype"); + break; + } +} + +template CUDANet::Tensor& CUDA::dense_impl( + const CUDANet::Tensor& weights, + const CUDANet::Tensor& biases, + const CUDANet::Tensor& input, + CUDANet::Tensor& output, + const size_t input_size, + const size_t output_size +); + +template +CUDANet::Tensor& CUDA::dense_impl( + const CUDANet::Tensor& weights, + const CUDANet::Tensor& biases, + const CUDANet::Tensor& input, + CUDANet::Tensor& output, + const size_t input_size, + const size_t output_size ) { auto forwardGridSize = (std::max(input_size, output_size) + BLOCK_SIZE - 1) / BLOCK_SIZE; auto biasGridSize = (output_size + BLOCK_SIZE - 1) / BLOCK_SIZE; Kernels::mat_vec_mul<<>>( - weights.data(), input.data(), output.data(), - input_size, output_size + weights.data(), input.data(), output.data(), input_size, + output_size ); CUDA_CHECK(cudaGetLastError()); Kernels::vec_vec_add<<>>( - biases.data(), output.data(), output.data(), - output_size + biases.data(), output.data(), output.data(), output_size ); CUDA_CHECK(cudaGetLastError()); CUDA_CHECK(cudaDeviceSynchronize()); @@ -92,6 +166,44 @@ CUDANet::Tensor& CUDA::conv2d( const CUDANet::Shape kernel_shape, const CUDANet::Shape stride_shape, const CUDANet::Shape out_shape +) { + switch (input.get_dtype()) { + case DType::FLOAT32: + return conv2d_impl( + weights, biases, input, output, in_shape, padding_shape, + kernel_shape, stride_shape, out_shape + ); + break; + + default: + throw std::runtime_error("Unsupported dtype"); + break; + } +} + +template CUDANet::Tensor& CUDA::conv2d_impl( + const CUDANet::Tensor& weights, + const CUDANet::Tensor& biases, + const CUDANet::Tensor& input, + CUDANet::Tensor& output, + const CUDANet::Shape in_shape, + const CUDANet::Shape padding_shape, + const CUDANet::Shape kernel_shape, + const CUDANet::Shape stride_shape, + const CUDANet::Shape out_shape +); + +template +CUDANet::Tensor& CUDA::conv2d_impl( + const CUDANet::Tensor& weights, + const CUDANet::Tensor& biases, + const CUDANet::Tensor& input, + CUDANet::Tensor& output, + const CUDANet::Shape in_shape, + const CUDANet::Shape padding_shape, + const CUDANet::Shape kernel_shape, + const CUDANet::Shape stride_shape, + const CUDANet::Shape out_shape ) { dim3 block(8, 8, 8); dim3 grid( @@ -101,9 +213,8 @@ CUDANet::Tensor& CUDA::conv2d( ); Kernels::convolution<<>>( - input.data(), weights.data(), biases.data(), - output.data(), in_shape, padding_shape, kernel_shape, - stride_shape, out_shape + input.data(), weights.data(), biases.data(), output.data(), + in_shape, padding_shape, kernel_shape, stride_shape, out_shape ); CUDA_CHECK(cudaGetLastError()); CUDA_CHECK(cudaDeviceSynchronize()); @@ -113,12 +224,46 @@ CUDANet::Tensor& CUDA::conv2d( CUDANet::Tensor& CUDA::max_pool2d( const CUDANet::Tensor& input, - CUDANet::Tensor& output, - CUDANet::Shape input_shape, - CUDANet::Shape pool_shape, - CUDANet::Shape stride_shape, - CUDANet::Shape padding_shape, - CUDANet::Shape output_shape + CUDANet::Tensor& output, + CUDANet::Shape input_shape, + CUDANet::Shape pool_shape, + CUDANet::Shape stride_shape, + CUDANet::Shape padding_shape, + CUDANet::Shape output_shape +) { + switch (input.get_dtype()) { + case DType::FLOAT32: + return max_pool2d_impl( + input, output, input_shape, pool_shape, stride_shape, + padding_shape, output_shape + ); + break; + + default: + throw std::runtime_error("Unsupported dtype"); + break; + } +} + +template CUDANet::Tensor& CUDA::max_pool2d_impl( + const CUDANet::Tensor& input, + CUDANet::Tensor& output, + CUDANet::Shape 
input_shape, + CUDANet::Shape pool_shape, + CUDANet::Shape stride_shape, + CUDANet::Shape padding_shape, + CUDANet::Shape output_shape +); + +template +CUDANet::Tensor& CUDA::max_pool2d_impl( + const CUDANet::Tensor& input, + CUDANet::Tensor& output, + CUDANet::Shape input_shape, + CUDANet::Shape pool_shape, + CUDANet::Shape stride_shape, + CUDANet::Shape padding_shape, + CUDANet::Shape output_shape ) { dim3 block(8, 8, 8); dim3 grid( @@ -128,8 +273,8 @@ CUDANet::Tensor& CUDA::max_pool2d( ); Kernels::max_pool<<>>( - input.data(), output.data(), input_shape, output_shape, pool_shape, - stride_shape, padding_shape + input.data(), output.data(), input_shape, output_shape, + pool_shape, stride_shape, padding_shape ); CUDA_CHECK(cudaGetLastError()); CUDA_CHECK(cudaDeviceSynchronize()); @@ -139,12 +284,46 @@ CUDANet::Tensor& CUDA::max_pool2d( CUDANet::Tensor& CUDA::avg_pool2d( const CUDANet::Tensor& input, - CUDANet::Tensor& output, - CUDANet::Shape input_shape, - CUDANet::Shape pool_shape, - CUDANet::Shape stride_shape, - CUDANet::Shape padding_shape, - CUDANet::Shape output_shape + CUDANet::Tensor& output, + CUDANet::Shape input_shape, + CUDANet::Shape pool_shape, + CUDANet::Shape stride_shape, + CUDANet::Shape padding_shape, + CUDANet::Shape output_shape +) { + switch (input.get_dtype()) { + case DType::FLOAT32: + return avg_pool2d_impl( + input, output, input_shape, pool_shape, stride_shape, + padding_shape, output_shape + ); + break; + + default: + throw std::runtime_error("Unsupported dtype"); + break; + } +} + +template CUDANet::Tensor& CUDA::avg_pool2d_impl( + const CUDANet::Tensor& input, + CUDANet::Tensor& output, + CUDANet::Shape input_shape, + CUDANet::Shape pool_shape, + CUDANet::Shape stride_shape, + CUDANet::Shape padding_shape, + CUDANet::Shape output_shape +); + +template +CUDANet::Tensor& CUDA::avg_pool2d_impl( + const CUDANet::Tensor& input, + CUDANet::Tensor& output, + CUDANet::Shape input_shape, + CUDANet::Shape pool_shape, + CUDANet::Shape stride_shape, + CUDANet::Shape padding_shape, + CUDANet::Shape output_shape ) { dim3 block(8, 8, 8); dim3 grid( @@ -154,8 +333,8 @@ CUDANet::Tensor& CUDA::avg_pool2d( ); Kernels::avg_pool<<>>( - input.data(), output.data(), input_shape, output_shape, pool_shape, - stride_shape, padding_shape + input.data(), output.data(), input_shape, output_shape, + pool_shape, stride_shape, padding_shape ); CUDA_CHECK(cudaGetLastError()); CUDA_CHECK(cudaDeviceSynchronize()); @@ -165,48 +344,84 @@ CUDANet::Tensor& CUDA::avg_pool2d( CUDANet::Tensor& CUDA::batch_norm( const CUDANet::Tensor& input, - CUDANet::Tensor& output, - CUDANet::Shape input_shape, - CUDANet::Tensor& weights, - CUDANet::Tensor& biases, - CUDANet::Tensor& running_mean, - CUDANet::Tensor& running_var, - CUDANet::Tensor& epsilon + CUDANet::Tensor& output, + CUDANet::Shape input_shape, + CUDANet::Tensor& weights, + CUDANet::Tensor& biases, + CUDANet::Tensor& running_mean, + CUDANet::Tensor& running_var, + CUDANet::Tensor& epsilon +) { + switch (input.get_dtype()) { + case DType::FLOAT32: + return batch_norm_impl( + input, output, input_shape, weights, biases, running_mean, + running_var, epsilon + ); + break; + + default: + throw std::runtime_error("Unsupported dtype"); + break; + } +} + +template CUDANet::Tensor& CUDA::batch_norm_impl( + const CUDANet::Tensor& input, + CUDANet::Tensor& output, + CUDANet::Shape input_shape, + CUDANet::Tensor& weights, + CUDANet::Tensor& biases, + CUDANet::Tensor& running_mean, + CUDANet::Tensor& running_var, + CUDANet::Tensor& epsilon +); + +template 
+CUDANet::Tensor& CUDA::batch_norm_impl( + const CUDANet::Tensor& input, + CUDANet::Tensor& output, + CUDANet::Shape input_shape, + CUDANet::Tensor& weights, + CUDANet::Tensor& biases, + CUDANet::Tensor& running_mean, + CUDANet::Tensor& running_var, + CUDANet::Tensor& epsilon ) { auto gridSize = (input_shape[0] * input_shape[1] + BLOCK_SIZE - 1) / BLOCK_SIZE; - for (int i = 0; i < input_shape[2]; i++) { // Subtract mean from input Kernels::vec_scalar_sub<<>>( - input.data() + i * input_shape[0] * input_shape[1], - output.data() + i * input_shape[0] * input_shape[1], - &running_mean.data()[i], input_shape[0] * input_shape[1] + input.data() + i * input_shape[0] * input_shape[1], + output.data() + i * input_shape[0] * input_shape[1], + &running_mean.data()[i], input_shape[0] * input_shape[1] ); CUDA_CHECK(cudaGetLastError()); // Divide by sqrt(running_var + epsilon) Kernels::vec_scale<<>>( - output.data() + i * input_shape[0] * input_shape[1], - output.data() + i * input_shape[0] * input_shape[1], - &running_var.data()[i], epsilon.data(), input_shape[0] * input_shape[1] + output.data() + i * input_shape[0] * input_shape[1], + output.data() + i * input_shape[0] * input_shape[1], + &running_var.data()[i], epsilon.data(), + input_shape[0] * input_shape[1] ); CUDA_CHECK(cudaGetLastError()); // Multiply by weights Kernels::vec_scalar_mul<<>>( - output.data() + i * input_shape[0] * input_shape[1], - output.data() + i * input_shape[0] * input_shape[1], &weights.data()[i], - input_shape[0] * input_shape[1] + output.data() + i * input_shape[0] * input_shape[1], + output.data() + i * input_shape[0] * input_shape[1], + &weights.data()[i], input_shape[0] * input_shape[1] ); CUDA_CHECK(cudaGetLastError()); // Add biases Kernels::vec_scalar_add<<>>( - output.data() + i * input_shape[0] * input_shape[1], - output.data() + i * input_shape[0] * input_shape[1], &biases.data()[i], - input_shape[0] * input_shape[1] + output.data() + i * input_shape[0] * input_shape[1], + output.data() + i * input_shape[0] * input_shape[1], + &biases.data()[i], input_shape[0] * input_shape[1] ); CUDA_CHECK(cudaGetLastError()); } @@ -218,14 +433,39 @@ CUDANet::Tensor& CUDA::concat( CUDANet::Tensor& input_a, CUDANet::Tensor& input_b, CUDANet::Tensor& output +) { + switch (input_a.get_dtype()) { + case DType::FLOAT32: + return concat_impl( + input_a, input_b, output + ); + break; + + default: + throw std::runtime_error("Unsupported dtype"); + break; + } +} + +template CUDANet::Tensor& CUDA::concat_impl( + CUDANet::Tensor& input_a, + CUDANet::Tensor& input_b, + CUDANet::Tensor& output +); + +template +CUDANet::Tensor& CUDA::concat_impl( + CUDANet::Tensor& input_a, + CUDANet::Tensor& input_b, + CUDANet::Tensor& output ) { CUDA_CHECK(cudaMemcpy( - output.data(), input_a.data(), input_a.size(), + output.data(), input_a.data(), input_a.size(), cudaMemcpyDeviceToDevice )); CUDA_CHECK(cudaMemcpy( - output.data() + input_a.numel(), input_b.data(), input_b.size(), + output.data() + input_a.numel(), input_b.data(), input_b.size(), cudaMemcpyDeviceToDevice )); @@ -239,11 +479,36 @@ CUDANet::Tensor& CUDA::add( CUDANet::Tensor& input_a, CUDANet::Tensor& input_b, CUDANet::Tensor& output +) { + switch (input_a.get_dtype()) { + case DType::FLOAT32: + return add_impl( + input_a, input_b, output + ); + break; + + default: + throw std::runtime_error("Unsupported dtype"); + break; + } +} + +template CUDANet::Tensor& CUDA::add_impl( + CUDANet::Tensor& input_a, + CUDANet::Tensor& input_b, + CUDANet::Tensor& output +); + +template +CUDANet::Tensor& 
CUDA::add_impl(
+    CUDANet::Tensor& input_a,
+    CUDANet::Tensor& input_b,
+    CUDANet::Tensor& output
 ) {
     auto gridSize = (input_a.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE;
 
     Kernels::vec_vec_add<<<gridSize, BLOCK_SIZE>>>(
-        input_a.data(), input_b.data(), output.data(), input_a.numel()
+        input_a.data(), input_b.data(), output.data(), input_a.numel()
     );
     CUDA_CHECK(cudaGetLastError());
     CUDA_CHECK(cudaDeviceSynchronize());
diff --git a/src/backends/cuda/tensor_ops.cu b/src/backends/cuda/tensor_ops.cu
index 979034a..a39d3ce 100644
--- a/src/backends/cuda/tensor_ops.cu
+++ b/src/backends/cuda/tensor_ops.cu
@@ -7,11 +7,26 @@
 using namespace CUDANet::Backends;
 
 void CUDA::print(const CUDANet::Tensor &input) {
+    switch (input.get_dtype()) {
+        case DType::FLOAT32:
+            print_impl<float>(input);
+            break;
+
+        default:
+            throw std::runtime_error("Unsupported dtype");
+            break;
+    }
+}
+
+template void CUDA::print_impl<float>(const CUDANet::Tensor &input);
+
+template <typename T>
+void CUDA::print_impl(const CUDANet::Tensor &input) {
     auto length = input.numel();
-    std::vector<float> h_vec(input.numel());
+    std::vector<T> h_vec(input.numel());
 
     CUDA_CHECK(cudaMemcpy(
-        h_vec.data(), input.data(), sizeof(float) * length, cudaMemcpyDeviceToHost
+        h_vec.data(), input.data(), sizeof(T) * length, cudaMemcpyDeviceToHost
     ));
 
     for (int i = 0; i < length; ++i) {
@@ -26,27 +41,71 @@ void CUDA::zero(CUDANet::Tensor &input) {
 }
 
 void CUDA::fill(CUDANet::Tensor &input, int value) {
-    CUDA_CHECK(cudaMemset(input.data(), value, sizeof(float) * input.numel()));
+    switch (input.get_dtype()) {
+        case DType::FLOAT32:
+            fill_impl<float>(input, value);
+            break;
+        default:
+            throw std::runtime_error("Unsupported dtype");
+            break;
+    }
+}
+
+template void CUDA::fill_impl<float>(CUDANet::Tensor &input, int value);
+
+template <typename T>
+void CUDA::fill_impl(CUDANet::Tensor &input, int value) {
+    CUDA_CHECK(cudaMemset(input.data(), value, sizeof(T) * input.numel()));
 }
 
 void CUDA::copy_to_device(CUDANet::Tensor &tensor, void *data, size_t size) {
-    CUDA_CHECK(cudaMemcpy(tensor.data(), data, size, cudaMemcpyHostToDevice));
+    switch (tensor.get_dtype()) {
+        case DType::FLOAT32:
+            copy_to_device_impl<float>(tensor, data, size);
+            break;
+
+        default:
+            throw std::runtime_error("Unsupported dtype");
+            break;
+    }
+}
+
+template void CUDA::copy_to_device_impl<float>(CUDANet::Tensor &tensor, void *data, size_t size);
+
+template <typename T>
+void CUDA::copy_to_device_impl(CUDANet::Tensor &tensor, void *data, size_t size) {
+    CUDA_CHECK(cudaMemcpy(tensor.data(), data, size, cudaMemcpyHostToDevice));
 }
 
 void CUDA::sum(const CUDANet::Tensor &input, CUDANet::Tensor &sum) {
+    switch (input.get_dtype()) {
+        case DType::FLOAT32:
+            sum_impl<float>(input, sum);
+            break;
+
+        default:
+            throw std::runtime_error("Unsupported dtype");
+            break;
+    }
+}
+
+template void CUDA::sum_impl<float>(const CUDANet::Tensor &input, CUDANet::Tensor &sum);
+
+template <typename T>
+void CUDA::sum_impl(const CUDANet::Tensor &input, CUDANet::Tensor &sum) {
     auto length = input.numel();
     const int gridSize = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
 
     CUDANet::Kernels::sum_reduce<<<gridSize, BLOCK_SIZE>>>(
-        input.data(), sum.data(), length
+        input.data(), sum.data(), length
     );
     CUDA_CHECK(cudaGetLastError());
 
     int remaining = gridSize;
     while (remaining > 1) {
         int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
-        CUDANet::Kernels::sum_reduce<<<blocks_needed, BLOCK_SIZE>>>(sum.data(), sum.data(), remaining);
+        CUDANet::Kernels::sum_reduce<<<blocks_needed, BLOCK_SIZE>>>(sum.data(), sum.data(), remaining);
         CUDA_CHECK(cudaGetLastError());
 
         remaining = blocks_needed;
@@ -54,17 +113,32 @@ void CUDA::sum(const CUDANet::Tensor &input, CUDANet::Tensor &sum) {
 }
 
 void CUDA::max(const CUDANet::Tensor &input, CUDANet::Tensor &max) {
+    switch (input.get_dtype()) {
+        case DType::FLOAT32:
+            max_impl<float>(input, max);
+            break;
+
+        default:
+            throw std::runtime_error("Unsupported dtype");
+            break;
+    }
+}
+
+template void CUDA::max_impl<float>(const CUDANet::Tensor &input, CUDANet::Tensor &max);
+
+template <typename T>
+void CUDA::max_impl(const CUDANet::Tensor &input, CUDANet::Tensor &max) {
     auto length = input.numel();
 
     const int grid_size = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
 
-    Kernels::max_reduce<<<grid_size, BLOCK_SIZE>>>(input.data(), max.data(), length);
+    Kernels::max_reduce<<<grid_size, BLOCK_SIZE>>>(input.data(), max.data(), length);
     CUDA_CHECK(cudaGetLastError());
 
     int remaining = grid_size;
     while (remaining > 1) {
         int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
-        CUDANet::Kernels::max_reduce<<<blocks_needed, BLOCK_SIZE>>>(max.data(), max.data(), remaining);
+        CUDANet::Kernels::max_reduce<<<blocks_needed, BLOCK_SIZE>>>(max.data(), max.data(), remaining);
         CUDA_CHECK(cudaGetLastError());
 
         remaining = blocks_needed;
diff --git a/src/layers/dense.cpp b/src/layers/dense.cpp
index b273c09..54b02fe 100644
--- a/src/layers/dense.cpp
+++ b/src/layers/dense.cpp
@@ -56,6 +56,7 @@ size_t Dense::output_size() {
     return out_shape[0];
 };
 
+// TODO: Use dtype
 void Dense::set_weights(void* input) {
     weights.set_data(static_cast<float*>(input));
 }
diff --git a/src/tensor.cpp b/src/tensor.cpp
index fc42057..16ee4de 100644
--- a/src/tensor.cpp
+++ b/src/tensor.cpp
@@ -80,6 +80,10 @@ Tensor::~Tensor() {
     }
 }
 
+DType Tensor::get_dtype() const {
+    return dtype;
+}
+
 size_t Tensor::numel() const {
     return total_elms;
 }
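
The refactor applies one pattern throughout: each kernel and backend helper becomes a template over the element type, the public backend methods stay non-templated and switch on Tensor::get_dtype() before forwarding to a typed *_impl, and each .cu file adds an explicit instantiation of the float variant so the templated definitions still link against the declarations in the .cuh headers. The following is a self-contained, single-file sketch of that pattern for a ReLU kernel, compilable with nvcc. DType, relu_impl and relu_dispatch here are illustrative stand-ins rather than the library's actual API, and the cuda_dtype_map trait is assumed to be keyed by DType, as its FLOAT32/float specialization in cuda.cuh suggests.

#include <cstdio>
#include <stdexcept>
#include <vector>
#include <cuda_runtime.h>

#define BLOCK_SIZE 128

enum class DType { FLOAT32 };  // stand-in for the library's DType enum

// Hypothetical mirror of the cuda_dtype_map trait: DType tag -> device type.
template <DType dtype>
struct cuda_dtype_map;

template <>
struct cuda_dtype_map<DType::FLOAT32> {
    using type = float;
};

// Templated kernel: same grid-stride loop as the float-only version,
// but the element type is a parameter.
template <typename T>
__global__ void relu_kernel(const T* src, T* dst, unsigned int len) {
    unsigned int stride = gridDim.x * blockDim.x;
    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < len; i += stride) {
        dst[i] = src[i] > static_cast<T>(0) ? src[i] : static_cast<T>(0);
    }
}

// Explicit instantiation: emits the float kernel in this translation unit,
// so code that only sees a template declaration in a header still links.
template __global__ void relu_kernel<float>(const float*, float*, unsigned int);

// Typed implementation, called only through the dispatcher below.
template <typename T>
void relu_impl(T* d_data, unsigned int len) {
    unsigned int grid = (len + BLOCK_SIZE - 1) / BLOCK_SIZE;
    relu_kernel<T><<<grid, BLOCK_SIZE>>>(d_data, d_data, len);
}

// Non-templated entry point: switch on the runtime dtype tag.
void relu_dispatch(void* d_data, unsigned int len, DType dtype) {
    switch (dtype) {
        case DType::FLOAT32:
            relu_impl<cuda_dtype_map<DType::FLOAT32>::type>(
                static_cast<float*>(d_data), len);
            break;
        default:
            throw std::runtime_error("Unsupported dtype");
    }
}

int main() {
    std::vector<float> h = {-2.0f, -0.5f, 0.0f, 1.5f};
    float* d = nullptr;
    cudaMalloc(&d, h.size() * sizeof(float));
    cudaMemcpy(d, h.data(), h.size() * sizeof(float), cudaMemcpyHostToDevice);
    relu_dispatch(d, static_cast<unsigned int>(h.size()), DType::FLOAT32);
    cudaMemcpy(h.data(), d, h.size() * sizeof(float), cudaMemcpyDeviceToHost);
    for (float v : h) printf("%g ", v);  // prints: 0 0 0 1.5
    printf("\n");
    cudaFree(d);
    return 0;
}

Under that structure, supporting another dtype comes down to a new cuda_dtype_map specialization, a new case in each dispatcher, and explicit instantiations of the kernels and *_impl templates for the new element type; the patch's dispatchers currently spell out float directly rather than going through the trait.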
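
Several of the templated kernels (vec_exp, vec_sqrt, vec_scale, max_reduce) still call the float intrinsics expf, sqrtf, rsqrtf and fmaxf and carry a "TODO: separate implementation for __half" comment. One way to resolve those TODOs without duplicating whole kernels is to route the math through a small overloaded device helper, as sketched below; this assumes the cuda_fp16.h intrinsics (hexp here, and similarly hsqrt/hrsqrt for the other kernels) and a GPU with native half arithmetic (sm_53 or newer), neither of which the patch itself commits to.

#include <cuda_fp16.h>

// Type-specific math helpers: the generic kernel stays a single template
// and the overload resolution picks the right intrinsic per element type.
__device__ __forceinline__ float  device_exp(float x)  { return expf(x); }
__device__ __forceinline__ __half device_exp(__half x) { return hexp(x); }  // sm_53+

// Grid-stride exponential in the style of Kernels::vec_exp.
template <typename T>
__global__ void vec_exp(
    const T* __restrict__ src,
    T* __restrict__ dst,
    const unsigned int len
) {
    unsigned int stride = gridDim.x * blockDim.x;
    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < len; i += stride) {
        dst[i] = device_exp(src[i]);
    }
}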
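
sum and max are computed in two stages: sum_reduce/max_reduce perform a shared-memory tree reduction per block and write one partial result per block, and the host loop in sum_impl/max_impl then re-launches the kernel over those partials until a single value remains, shrinking the element count from remaining to ceil(remaining / BLOCK_SIZE) on every pass. With BLOCK_SIZE 128, a 1,000,000-element input reduces to 7813 partials, then 62, then 1. The standalone sketch below follows the same scheme; unlike the in-place loop in tensor_ops.cu it ping-pongs between two scratch buffers so that no pass overwrites partials it is still reading.

#include <utility>
#include <cuda_runtime.h>

#define BLOCK_SIZE 128

// One shared-memory tree reduction per block; block b writes its
// partial sum to out[b].
template <typename T>
__global__ void sum_reduce(
    const T* __restrict__ in,
    T* __restrict__ out,
    const unsigned int len
) {
    __shared__ T partial[BLOCK_SIZE];
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    partial[threadIdx.x] = (i < len) ? in[i] : static_cast<T>(0);
    __syncthreads();

    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) {
            partial[threadIdx.x] += partial[threadIdx.x + s];
        }
        __syncthreads();
    }
    if (threadIdx.x == 0) {
        out[blockIdx.x] = partial[0];
    }
}

// Host loop: keep reducing the partial results until one value is left.
template <typename T>
T device_sum(const T* d_in, unsigned int len) {
    unsigned int grid = (len + BLOCK_SIZE - 1) / BLOCK_SIZE;

    T *d_a = nullptr, *d_b = nullptr;
    cudaMalloc(&d_a, grid * sizeof(T));
    cudaMalloc(&d_b, grid * sizeof(T));

    // First pass: one partial per block.
    sum_reduce<T><<<grid, BLOCK_SIZE>>>(d_in, d_a, len);

    unsigned int remaining = grid;
    while (remaining > 1) {
        unsigned int blocks = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
        sum_reduce<T><<<blocks, BLOCK_SIZE>>>(d_a, d_b, remaining);
        std::swap(d_a, d_b);  // next pass reads what was just written
        remaining = blocks;
    }

    T result{};
    cudaMemcpy(&result, d_a, sizeof(T), cudaMemcpyDeviceToHost);
    cudaFree(d_a);
    cudaFree(d_b);
    return result;
}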