Refactor CUDA kernels and tensor operations for type generality

This commit is contained in:
2025-11-26 20:47:55 +01:00
parent 13d3d38b68
commit 9ff214d759
14 changed files with 818 additions and 297 deletions

View File

@@ -8,53 +8,60 @@
#ifndef BLOCK_SIZE
#define BLOCK_SIZE 128
#endif // BLOCK_SIZE
#endif // BLOCK_SIZE
/**
* @brief CUDA error checking macro
*
*
*/
#define CUDA_CHECK(call) \
do { \
cudaError_t result = call; \
if (result != cudaSuccess) { \
fprintf(stderr, "CUDA error at %s:%d code=%d(%s) \"%s\" \n", \
__FILE__, __LINE__, static_cast<unsigned int>(result), \
cudaGetErrorString(result), #call); \
exit(EXIT_FAILURE); \
} \
} while (0)
#define CUDA_CHECK(call) \
do { \
cudaError_t result = call; \
if (result != cudaSuccess) { \
fprintf( \
stderr, "CUDA error at %s:%d code=%d(%s) \"%s\" \n", __FILE__, \
__LINE__, static_cast<unsigned int>(result), \
cudaGetErrorString(result), #call \
); \
exit(EXIT_FAILURE); \
} \
} while (0)
namespace CUDANet::Backends {
template <DType dtype>
struct cuda_dtype_map;
template <>
struct cuda_dtype_map<DType::FLOAT32> {
using type = float;
};
class CUDA : public Backend {
private:
int device_id;
std::set<DType> supported_dtypes;
public:
CUDA(const BackendConfig& config);
bool supports_dtype(DType dtype) const override;
void set_default_dtype(DType dtype) override;
bool supports_dtype(DType dtype) const override;
void set_default_dtype(DType dtype) override;
DType get_default_dtype() const override;
static bool is_cuda_available();
void initialize();
void initialize();
// Memory management
void* allocate(size_t bytes) override;
void deallocate(void* ptr) override;
// Tensor ops
// Tensor ops dispatchers
void print(const CUDANet::Tensor& input) override;
void zero(CUDANet::Tensor& input) override;
void fill(CUDANet::Tensor &input, int value) override;
void fill(CUDANet::Tensor& input, int value) override;
void
copy_to_device(CUDANet::Tensor& tensor, void* data, size_t size) override;
void sum(const CUDANet::Tensor& input, CUDANet::Tensor& sum) override;
void max(const CUDANet::Tensor& input, CUDANet::Tensor& max) override;
// Layer ops
// Layer ops dispatchers
void relu(CUDANet::Tensor& tensor) override;
void sigmoid(CUDANet::Tensor& tensor) override;
void softmax(
@@ -67,7 +74,7 @@ class CUDA : public Backend {
const CUDANet::Tensor& weights,
const CUDANet::Tensor& biases,
const CUDANet::Tensor& input,
CUDANet::Tensor& output,
CUDANet::Tensor& output,
const size_t input_size,
const size_t output_size
) override;
@@ -76,43 +83,43 @@ class CUDA : public Backend {
const CUDANet::Tensor& weights,
const CUDANet::Tensor& biases,
const CUDANet::Tensor& input,
CUDANet::Tensor& output,
const CUDANet::Shape in_shape,
const CUDANet::Shape padding_shape,
const CUDANet::Shape kernel_shape,
const CUDANet::Shape stride_shape,
const CUDANet::Shape out_shape
CUDANet::Tensor& output,
const CUDANet::Shape in_shape,
const CUDANet::Shape padding_shape,
const CUDANet::Shape kernel_shape,
const CUDANet::Shape stride_shape,
const CUDANet::Shape out_shape
) override;
CUDANet::Tensor& max_pool2d(
const CUDANet::Tensor& input,
CUDANet::Tensor& output,
CUDANet::Shape input_shape,
CUDANet::Shape pool_shape,
CUDANet::Shape stride_shape,
CUDANet::Shape padding_shape,
CUDANet::Shape output_shape
CUDANet::Tensor& output,
CUDANet::Shape input_shape,
CUDANet::Shape pool_shape,
CUDANet::Shape stride_shape,
CUDANet::Shape padding_shape,
CUDANet::Shape output_shape
) override;
CUDANet::Tensor& avg_pool2d(
const CUDANet::Tensor& input,
CUDANet::Tensor& output,
CUDANet::Shape input_shape,
CUDANet::Shape pool_shape,
CUDANet::Shape stride_shape,
CUDANet::Shape padding_shape,
CUDANet::Shape output_shape
CUDANet::Tensor& output,
CUDANet::Shape input_shape,
CUDANet::Shape pool_shape,
CUDANet::Shape stride_shape,
CUDANet::Shape padding_shape,
CUDANet::Shape output_shape
) override;
CUDANet::Tensor& batch_norm(
const CUDANet::Tensor& input,
CUDANet::Tensor& output,
CUDANet::Shape input_shape,
CUDANet::Tensor& weights,
CUDANet::Tensor& biases,
CUDANet::Tensor& running_mean,
CUDANet::Tensor& running_var,
CUDANet::Tensor& epsilon
CUDANet::Tensor& output,
CUDANet::Shape input_shape,
CUDANet::Tensor& weights,
CUDANet::Tensor& biases,
CUDANet::Tensor& running_mean,
CUDANet::Tensor& running_var,
CUDANet::Tensor& epsilon
) override;
CUDANet::Tensor& concat(
@@ -126,6 +133,111 @@ class CUDA : public Backend {
CUDANet::Tensor& input_b,
CUDANet::Tensor& output
) override;
private:
int device_id;
std::set<DType> supported_dtypes;
// Tensor ops template impls
template <typename T>
void print_impl(const CUDANet::Tensor& input);
template <typename T>
void fill_impl(CUDANet::Tensor& input, int value);
template <typename T>
void copy_to_device_impl(CUDANet::Tensor& tensor, void* data, size_t size);
template <typename T>
void sum_impl(const CUDANet::Tensor& input, CUDANet::Tensor& sum);
template <typename T>
void max_impl(const CUDANet::Tensor& input, CUDANet::Tensor& max);
// Layer ops template impls
template <typename T>
void relu_impl(CUDANet::Tensor& tensor);
template <typename T>
void sigmoid_impl(CUDANet::Tensor& tensor);
template <typename T>
void softmax_impl(
CUDANet::Tensor& tensor,
CUDANet::Tensor& temp_max,
CUDANet::Tensor& temp_sum
);
template <typename T>
CUDANet::Tensor& dense_impl(
const CUDANet::Tensor& weights,
const CUDANet::Tensor& biases,
const CUDANet::Tensor& input,
CUDANet::Tensor& output,
const size_t input_size,
const size_t output_size
);
template <typename T>
CUDANet::Tensor& conv2d_impl(
const CUDANet::Tensor& weights,
const CUDANet::Tensor& biases,
const CUDANet::Tensor& input,
CUDANet::Tensor& output,
const CUDANet::Shape in_shape,
const CUDANet::Shape padding_shape,
const CUDANet::Shape kernel_shape,
const CUDANet::Shape stride_shape,
const CUDANet::Shape out_shape
);
template <typename T>
CUDANet::Tensor& max_pool2d_impl(
const CUDANet::Tensor& input,
CUDANet::Tensor& output,
CUDANet::Shape input_shape,
CUDANet::Shape pool_shape,
CUDANet::Shape stride_shape,
CUDANet::Shape padding_shape,
CUDANet::Shape output_shape
);
template <typename T>
CUDANet::Tensor& avg_pool2d_impl(
const CUDANet::Tensor& input,
CUDANet::Tensor& output,
CUDANet::Shape input_shape,
CUDANet::Shape pool_shape,
CUDANet::Shape stride_shape,
CUDANet::Shape padding_shape,
CUDANet::Shape output_shape
);
template <typename T>
CUDANet::Tensor& batch_norm_impl(
const CUDANet::Tensor& input,
CUDANet::Tensor& output,
CUDANet::Shape input_shape,
CUDANet::Tensor& weights,
CUDANet::Tensor& biases,
CUDANet::Tensor& running_mean,
CUDANet::Tensor& running_var,
CUDANet::Tensor& epsilon
);
template <typename T>
CUDANet::Tensor& concat_impl(
CUDANet::Tensor& input_a,
CUDANet::Tensor& input_b,
CUDANet::Tensor& output
);
template <typename T>
CUDANet::Tensor& add_impl(
CUDANet::Tensor& input_a,
CUDANet::Tensor& input_b,
CUDANet::Tensor& output
);
};
} // namespace CUDANet::Backend
} // namespace CUDANet::Backends

View File

@@ -4,29 +4,18 @@
namespace CUDANet::Kernels {
/**
* @brief Sigmoid activation function kernel
*
* @param src Pointer to the source array
* @param dst Pointer to the destination array
* @param len Length of the arrays
*/
template <typename T>
__global__ void sigmoid(
const float* __restrict__ src,
float* __restrict__ dst,
const T* __restrict__ src,
T* __restrict__ dst,
const unsigned int len
);
/**
* @brief Relu activation function kernel
*
* @param src Pointer to the source array
* @param dst Pointer to the destination array
* @param len Length of the arrays
*/
template <typename T>
__global__ void relu(
const float* __restrict__ src,
float* __restrict__ dst,
const T* __restrict__ src,
T* __restrict__ dst,
const unsigned int len
);

View File

@@ -5,11 +5,12 @@
namespace CUDANet::Kernels {
template <typename T>
__global__ void convolution(
const float* __restrict__ d_input,
const float* __restrict__ d_kernel,
const float* __restrict__ d_bias,
float* __restrict__ d_output,
const T* __restrict__ d_input,
const T* __restrict__ d_kernel,
const T* __restrict__ d_bias,
T* __restrict__ d_output,
const Shape input_shape,
const Shape padding_shape,
const Shape kernel_shape,

View File

@@ -4,188 +4,105 @@
namespace CUDANet::Kernels {
/**
* @brief Matrix vector multiplication kernel
*
* @param d_matrix Device pointer to matrix
* @param d_vector Device pointer to vector
* @param d_output Device pointer to output vector
* @param w Width of the matrix
* @param h Height of the matrix
*/
template <typename T>
__global__ void mat_vec_mul(
const float* __restrict__ d_matrix,
const float* __restrict__ d_vector,
float* __restrict__ d_output,
const T* __restrict__ d_matrix,
const T* __restrict__ d_vector,
T* __restrict__ d_output,
const unsigned int w,
const unsigned int h
);
/**
* @brief Vector vector addition kernel
*
* @param d_vector1 Device pointer to first vector
* @param d_vector2 Device pointer to second vector
* @param d_output Device pointer to output vector
* @param w Length of the vectors
*/
template <typename T>
__global__ void vec_vec_add(
const float* __restrict__ d_vector1,
const float* __restrict__ d_vector2,
float* __restrict__ d_output,
const T* __restrict__ d_vector1,
const T* __restrict__ d_vector2,
T* __restrict__ d_output,
const unsigned int w
);
/**
* @brief Vector vector subtraction kernel
*
* @param d_vector1
* @param d_vector2
* @param d_output
* @param w
* @return __global__
*/
template <typename T>
__global__ void vec_vec_sub(
const float* __restrict__ d_vector1,
const float* __restrict__ d_vector2,
float* __restrict__ d_output,
const T* __restrict__ d_vector1,
const T* __restrict__ d_vector2,
T* __restrict__ d_output,
const unsigned int w
);
template <typename T>
__global__ void vec_vec_mul(
const float* __restrict__ d_vector1,
const float* __restrict__ d_vector2,
float* __restrict__ d_output,
const T* __restrict__ d_vector1,
const T* __restrict__ d_vector2,
T* __restrict__ d_output,
const unsigned int w
);
/**
* @brief Sub scalar from each element of the vector
*
* @param d_vector
* @param d_scalar
* @param d_output
* @param w
* @return __global__
*/
template <typename T>
__global__ void vec_scalar_sub(
const float* __restrict__ d_src,
float* __restrict__ d_out,
const float* __restrict__ d_scalar,
const T* __restrict__ d_src,
T* __restrict__ d_out,
const T* __restrict__ d_scalar,
const unsigned int len
);
/**
* @brief Add scalar to each element of the vector
*
* @param d_src
* @param d_out
* @param d_scalar
* @param len
* @return __global__
*/
template <typename T>
__global__ void vec_scalar_add(
const float* __restrict__ d_src,
float* __restrict__ d_out,
const float* __restrict__ d_scalar,
const T* __restrict__ d_src,
T* __restrict__ d_out,
const T* __restrict__ d_scalar,
const unsigned int len
);
/**
* @brief Divide each element of the vector by a scalar
*
* @param src Pointer to the source array
* @param dst Pointer to the destination array
* @param len Length of the arrays
*/
template <typename T>
__global__ void vec_scalar_div(
const float* __restrict__ d_src,
float* __restrict__ d_out,
const float* __restrict__ d_scalar,
const T* __restrict__ d_src,
T* __restrict__ d_out,
const T* __restrict__ d_scalar,
const unsigned int len
);
/**
* @brief Multiply each element of the vector by a scalar
*
* @param d_src
* @param d_out
* @param d_scalar
* @param len
* @return __global__
*/
template <typename T>
__global__ void vec_scalar_mul(
const float* __restrict__ d_src,
float* __restrict__ d_out,
const float* __restrict__ d_scalar,
const T* __restrict__ d_src,
T* __restrict__ d_out,
const T* __restrict__ d_scalar,
const unsigned int len
);
/**
* @brief Exponentiate each element of the vector
*
* @param src Pointer to the source array
* @param dst Pointer to the destination array
* @param len Length of the arrays
*/
template <typename T>
__global__ void vec_exp(
const float* __restrict__ src,
float* __restrict__ dst,
const T* __restrict__ src,
T* __restrict__ dst,
const unsigned int len
);
/**
* @brief Compute the square root of each element of the vector
*
* @param src Device pointer to source vector
* @param dst Device pointer to destination vector
* @param len Length of the vector
*/
template <typename T>
__global__ void vec_sqrt(
const float* __restrict__ src,
float* __restrict__ dst,
const T* __restrict__ src,
T* __restrict__ dst,
const unsigned int len
);
/**
* @brief Scales the vector by 1/sqrt(scale + epsilon)
*
* @param src Device pointer to source vector
* @param dst Device pointer to destination vector
* @param scale Scale
* @param epsilon Epsilon
* @param len Length of the vector
*/
template <typename T>
__global__ void vec_scale(
const float* __restrict__ src,
float* __restrict__ dst,
const float* __restrict__ scale,
const float* epsilon,
const T* __restrict__ src,
T* __restrict__ dst,
const T* __restrict__ scale,
const T* epsilon,
const unsigned int len
);
/**
* @brief Max reduction kernel
*
* @param d_vector Device pointer to vector
* @param d_output Device pointer to output vector
*/
template <typename T>
__global__ void max_reduce(
const float* __restrict__ d_vector,
float* __restrict__ d_output,
const T* __restrict__ d_vector,
T* __restrict__ d_output,
const unsigned int len
);
/**
* @brief
*
* @param d_vector Device pointer to vector
* @param d_output Device pointer to output vector
* @param len Length of the vector
*/
template <typename T>
__global__ void sum_reduce(
const float* __restrict__ d_vector,
float* __restrict__ d_output,
const T* __restrict__ d_vector,
T* __restrict__ d_output,
const unsigned int len
);

View File

@@ -5,9 +5,10 @@
namespace CUDANet::Kernels {
template <typename T>
__global__ void max_pool(
const float* __restrict__ d_input,
float* __restrict__ d_output,
const T* __restrict__ d_input,
T* __restrict__ d_output,
const Shape input_shape,
const Shape output_shape,
const Shape pool_shape,
@@ -15,9 +16,10 @@ __global__ void max_pool(
const Shape padding_shape
);
template <typename T>
__global__ void avg_pool(
const float* __restrict__ d_input,
float* __restrict__ d_output,
const T* __restrict__ d_input,
T* __restrict__ d_output,
const Shape input_shape,
const Shape output_shape,
const Shape pool_shape,

View File

@@ -33,7 +33,7 @@ public:
~Tensor();
DType get_dtype();
DType get_dtype() const;
size_t size() const;
size_t numel() const;

View File

@@ -2,10 +2,18 @@
using namespace CUDANet;
__global__ void Kernels::sigmoid(
template
__global__ void Kernels::sigmoid<float>(
const float* __restrict__ src,
float* __restrict__ dst,
const unsigned int len
);
template <typename T>
__global__ void Kernels::sigmoid(
const T* __restrict__ src,
T* __restrict__ dst,
const unsigned int len
) {
int stride = gridDim.x * blockDim.x;
int tid = blockDim.x * blockIdx.x + threadIdx.x;
@@ -15,10 +23,17 @@ __global__ void Kernels::sigmoid(
}
}
__global__ void Kernels::relu(
template __global__ void Kernels::relu<float>(
const float* __restrict__ src,
float* __restrict__ dst,
const unsigned int len
);
template <typename T>
__global__ void Kernels::relu(
const T* __restrict__ src,
T* __restrict__ dst,
const unsigned int len
) {
int stride = gridDim.x * blockDim.x;
int tid = blockDim.x * blockIdx.x + threadIdx.x;

View File

@@ -4,7 +4,7 @@
using namespace CUDANet;
__global__ void Kernels::convolution(
template __global__ void Kernels::convolution<float>(
const float* __restrict__ d_input,
const float* __restrict__ d_kernel,
const float* __restrict__ d_bias,
@@ -14,6 +14,19 @@ __global__ void Kernels::convolution(
const Shape kernel_shape,
const Shape stride_shape,
const Shape output_shape
);
template <typename T>
__global__ void Kernels::convolution(
const T* __restrict__ d_input,
const T* __restrict__ d_kernel,
const T* __restrict__ d_bias,
T* __restrict__ d_output,
const Shape input_shape,
const Shape padding_shape,
const Shape kernel_shape,
const Shape stride_shape,
const Shape output_shape
) {
int j = blockDim.x * blockIdx.x + threadIdx.x;
int i = blockDim.y * blockIdx.y + threadIdx.y;
@@ -23,7 +36,7 @@ __global__ void Kernels::convolution(
return;
}
float sum = 0.0f;
T sum = static_cast<t>(0);
// Iterate over kernel and input matrix
for (int c = 0; c < input_shape[2]; c++) {

View File

@@ -3,17 +3,26 @@
using namespace CUDANet;
__global__ void Kernels::mat_vec_mul(
template __global__ void Kernels::mat_vec_mul<float>(
const float* __restrict__ d_matrix,
const float* __restrict__ d_vector,
float* __restrict__ d_output,
const unsigned int w,
const unsigned int h
);
template <typename T>
__global__ void Kernels::mat_vec_mul(
const T* __restrict__ d_matrix,
const T* __restrict__ d_vector,
T* __restrict__ d_output,
const unsigned int w,
const unsigned int h
) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid < h) {
float temp = 0.0f;
T temp = static_cast<T>(0);
for (unsigned int j = 0; j < w; j++) {
temp += d_matrix[tid * w + j] * d_vector[j];
@@ -23,11 +32,19 @@ __global__ void Kernels::mat_vec_mul(
}
}
__global__ void Kernels::vec_vec_add(
template __global__ void Kernels::vec_vec_add<float>(
const float* __restrict__ d_vector1,
const float* __restrict__ d_vector2,
float* __restrict__ d_output,
const unsigned int w
);
template <typename T>
__global__ void Kernels::vec_vec_add(
const T* __restrict__ d_vector1,
const T* __restrict__ d_vector2,
T* __restrict__ d_output,
const unsigned int w
) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= w) {
@@ -36,11 +53,19 @@ __global__ void Kernels::vec_vec_add(
d_output[tid] = d_vector1[tid] + d_vector2[tid];
}
__global__ void Kernels::vec_vec_sub(
template __global__ void Kernels::vec_vec_sub<float>(
const float* __restrict__ d_vector1,
const float* __restrict__ d_vector2,
float* __restrict__ d_output,
const unsigned int w
);
template <typename T>
__global__ void Kernels::vec_vec_sub(
const T* __restrict__ d_vector1,
const T* __restrict__ d_vector2,
T* __restrict__ d_output,
const unsigned int w
) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= w) {
@@ -49,11 +74,19 @@ __global__ void Kernels::vec_vec_sub(
d_output[tid] = d_vector1[tid] - d_vector2[tid];
}
__global__ void Kernels::vec_vec_mul(
template __global__ void Kernels::vec_vec_mul<float>(
const float* __restrict__ d_vector1,
const float* __restrict__ d_vector2,
float* __restrict__ d_output,
const unsigned int w
);
template <typename T>
__global__ void Kernels::vec_vec_mul(
const T* __restrict__ d_vector1,
const T* __restrict__ d_vector2,
T* __restrict__ d_output,
const unsigned int w
) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= w) {
@@ -62,11 +95,19 @@ __global__ void Kernels::vec_vec_mul(
d_output[tid] = d_vector1[tid] * d_vector2[tid];
}
__global__ void Kernels::vec_scalar_sub(
template __global__ void Kernels::vec_scalar_sub<float>(
const float* __restrict__ d_src,
float* __restrict__ d_out,
const float* __restrict__ d_scalar,
const unsigned int len
);
template <typename T>
__global__ void Kernels::vec_scalar_sub(
const T* __restrict__ d_src,
T* __restrict__ d_out,
const T* __restrict__ d_scalar,
const unsigned int len
) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= len) {
@@ -75,11 +116,19 @@ __global__ void Kernels::vec_scalar_sub(
d_out[tid] = d_src[tid] - *d_scalar;
}
__global__ void Kernels::vec_scalar_add(
template __global__ void Kernels::vec_scalar_add<float>(
const float* __restrict__ d_src,
float* __restrict__ d_out,
const float* __restrict__ d_scalar,
const unsigned int len
);
template <typename T>
__global__ void Kernels::vec_scalar_add(
const T* __restrict__ d_src,
T* __restrict__ d_out,
const T* __restrict__ d_scalar,
const unsigned int len
) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= len) {
@@ -88,11 +137,19 @@ __global__ void Kernels::vec_scalar_add(
d_out[tid] = d_src[tid] + *d_scalar;
}
__global__ void Kernels::vec_scalar_div(
template __global__ void Kernels::vec_scalar_div<float>(
const float* __restrict__ d_src,
float* __restrict__ d_out,
const float* __restrict__ d_scalar,
const unsigned int len
);
template <typename T>
__global__ void Kernels::vec_scalar_div(
const T* __restrict__ d_src,
T* __restrict__ d_out,
const T* __restrict__ d_scalar,
const unsigned int len
) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= len) {
@@ -101,11 +158,19 @@ __global__ void Kernels::vec_scalar_div(
d_out[tid] = d_src[tid] / *d_scalar;
}
__global__ void Kernels::vec_scalar_mul(
template __global__ void Kernels::vec_scalar_mul<float>(
const float* __restrict__ d_src,
float* __restrict__ d_out,
const float* __restrict__ d_scalar,
const unsigned int len
);
template <typename T>
__global__ void Kernels::vec_scalar_mul(
const T* __restrict__ d_src,
T* __restrict__ d_out,
const T* __restrict__ d_scalar,
const unsigned int len
) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= len) {
@@ -114,64 +179,98 @@ __global__ void Kernels::vec_scalar_mul(
d_out[tid] = d_src[tid] * *d_scalar;
}
__global__ void Kernels::vec_exp(
template __global__ void Kernels::vec_exp<float>(
const float* __restrict__ src,
float* __restrict__ dst,
const unsigned int len
);
template <typename T>
__global__ void Kernels::vec_exp(
const T* __restrict__ src,
T* __restrict__ dst,
const unsigned int len
) {
int stride = gridDim.x * blockDim.x;
int tid = blockDim.x * blockIdx.x + threadIdx.x;
for (int i = tid; i < len; i += stride) {
// TODO: separate implementation for __half
dst[i] = expf(src[i]);
}
}
__global__ void Kernels::vec_sqrt(
template __global__ void Kernels::vec_sqrt<float>(
const float* __restrict__ src,
float* __restrict__ dst,
const unsigned int len
);
template <typename T>
__global__ void Kernels::vec_sqrt(
const T* __restrict__ src,
T* __restrict__ dst,
const unsigned int len
) {
int stride = gridDim.x * blockDim.x;
int tid = blockDim.x * blockIdx.x + threadIdx.x;
int tid = blockDim.x * blockIdx.x + threadIdx.x;
for (int i = tid; i < len; i += stride) {
// TODO: separate implementation for __half
dst[i] = sqrtf(src[i]);
}
}
__global__ void Kernels::vec_scale(
template __global__ void Kernels::vec_scale<float>(
const float* __restrict__ src,
float* __restrict__ dst,
const float* __restrict__ scale,
const float* epsilon,
const unsigned int len
);
template <typename T>
__global__ void Kernels::vec_scale(
const T* __restrict__ src,
T* __restrict__ dst,
const T* __restrict__ scale,
const T* epsilon,
const unsigned int len
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < len) {
// TODO: separate implementation for __half
float inv_std = rsqrtf(*scale + *epsilon);
dst[idx] = src[idx] * inv_std;
}
}
__global__ void Kernels::max_reduce(
template __global__ void Kernels::max_reduce<float>(
const float* __restrict__ d_vector,
float* __restrict__ d_output,
const unsigned int len
);
template <typename T>
__global__ void Kernels::max_reduce(
const T* __restrict__ d_vector,
T* __restrict__ d_output,
const unsigned int len
) {
__shared__ float shared_max[BLOCK_SIZE];
__shared__ T shared_max[BLOCK_SIZE];
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
shared_max[threadIdx.x] = d_vector[i];
} else {
shared_max[threadIdx.x] = -INFINITY;
}
}
__syncthreads();
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
if (threadIdx.x < s) {
// TODO: separate implementation for __half
shared_max[threadIdx.x] = fmaxf(shared_max[threadIdx.x], shared_max[threadIdx.x + s]);
}
__syncthreads();
@@ -182,18 +281,25 @@ __global__ void Kernels::max_reduce(
}
}
__global__ void Kernels::sum_reduce(
template __global__ void Kernels::sum_reduce<float>(
const float* __restrict__ d_vector,
float* __restrict__ d_output,
const unsigned int len
);
template <typename T>
__global__ void Kernels::sum_reduce(
const T* __restrict__ d_vector,
T* __restrict__ d_output,
const unsigned int len
) {
__shared__ float partial_sum[BLOCK_SIZE];
__shared__ T partial_sum[BLOCK_SIZE];
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
partial_sum[threadIdx.x] = d_vector[i];
} else {
partial_sum[threadIdx.x] = 0.0f;
partial_sum[threadIdx.x] = static_cast<T>(0);
}
__syncthreads();
@@ -208,4 +314,4 @@ __global__ void Kernels::sum_reduce(
if (threadIdx.x == 0) {
d_output[blockIdx.x] = partial_sum[0];
}
}
}

View File

@@ -3,7 +3,7 @@
using namespace CUDANet;
__global__ void Kernels::max_pool(
template __global__ void Kernels::max_pool<float>(
const float* __restrict__ d_input,
float* __restrict__ d_output,
const Shape input_shape,
@@ -11,6 +11,17 @@ __global__ void Kernels::max_pool(
const Shape pool_shape,
const Shape stride_shape,
const Shape padding_shape
);
template <typename T>
__global__ void Kernels::max_pool(
const T* __restrict__ d_input,
T* __restrict__ d_output,
const Shape input_shape,
const Shape output_shape,
const Shape pool_shape,
const Shape stride_shape,
const Shape padding_shape
) {
int j = blockDim.x * blockIdx.x + threadIdx.x;
int i = blockDim.y * blockIdx.y + threadIdx.y;
@@ -20,7 +31,7 @@ __global__ void Kernels::max_pool(
return;
}
float max = 0.0f;
T max = static_cast<T>(0);
for (int k = 0; k < pool_shape[0]; k++) {
for (int l = 0; l < pool_shape[1]; l++) {
@@ -43,7 +54,7 @@ __global__ void Kernels::max_pool(
max;
}
__global__ void Kernels::avg_pool(
template __global__ void Kernels::avg_pool<float>(
const float* __restrict__ d_input,
float* __restrict__ d_output,
const Shape input_shape,
@@ -51,6 +62,17 @@ __global__ void Kernels::avg_pool(
const Shape pool_shape,
const Shape stride_shape,
const Shape padding_shape
);
template <typename T>
__global__ void Kernels::avg_pool(
const T* __restrict__ d_input,
T* __restrict__ d_output,
const Shape input_shape,
const Shape output_shape,
const Shape pool_shape,
const Shape stride_shape,
const Shape padding_shape
) {
int j = blockDim.x * blockIdx.x + threadIdx.x;
int i = blockDim.y * blockIdx.y + threadIdx.y;
@@ -60,7 +82,7 @@ __global__ void Kernels::avg_pool(
return;
}
float sum = 0.0f;
T sum = static_cast<T>(0);
for (int k = 0; k < pool_shape[0]; k++) {
for (int l = 0; l < pool_shape[1]; l++) {

View File

@@ -7,24 +7,70 @@
using namespace CUDANet::Backends;
void CUDA::relu(Tensor& tensor) {
switch (tensor.get_dtype()) {
case DType::FLOAT32:
relu_impl<float>(tensor);
break;
default:
throw std::runtime_error("Unsupported dtype");
break;
}
}
template void CUDA::relu_impl<float>(Tensor& tensor);
template <typename T>
void CUDA::relu_impl(Tensor& tensor) {
int gridSize = (tensor.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE;
Kernels::relu<<<gridSize, BLOCK_SIZE>>>(
tensor.data<float>(), tensor.data<float>(), tensor.numel()
tensor.data<T>(), tensor.data<T>(), tensor.numel()
);
CUDA_CHECK(cudaGetLastError());
CUDA_CHECK(cudaDeviceSynchronize());
}
void CUDA::sigmoid(Tensor& tensor) {
void CUDA::sigmoid(CUDANet::Tensor& tensor) {
switch (tensor.get_dtype()) {
case DType::FLOAT32:
sigmoid_impl<float>(tensor);
break;
default:
throw std::runtime_error("Unsupported dtype");
break;
}
}
template void CUDA::sigmoid_impl<float>(Tensor& tensor);
template <typename T>
void CUDA::sigmoid_impl(CUDANet::Tensor& tensor) {
int gridSize = (tensor.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE;
Kernels::sigmoid<<<gridSize, BLOCK_SIZE>>>(
tensor.data<float>(), tensor.data<float>(), tensor.numel()
tensor.data<T>(), tensor.data<T>(), tensor.numel()
);
CUDA_CHECK(cudaGetLastError());
CUDA_CHECK(cudaDeviceSynchronize());
}
void CUDA::softmax(Tensor& tensor, Tensor& temp_max, Tensor& temp_sum) {
switch (tensor.get_dtype()) {
case DType::FLOAT32:
softmax_impl<float>(tensor, temp_max, temp_sum);
break;
default:
throw std::runtime_error("Unsupported dtype");
break;
}
}
template void
CUDA::softmax_impl<float>(Tensor& tensor, Tensor& temp_max, Tensor& temp_sum);
template <typename T>
void CUDA::softmax_impl(Tensor& tensor, Tensor& temp_max, Tensor& temp_sum) {
int gridSize = (tensor.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE;
// Find max value
@@ -32,14 +78,13 @@ void CUDA::softmax(Tensor& tensor, Tensor& temp_max, Tensor& temp_sum) {
// Subtract max value to improve numerical stability
Kernels::vec_scalar_sub<<<gridSize, BLOCK_SIZE>>>(
tensor.data<float>(), tensor.data<float>(), temp_max.data<float>(),
tensor.numel()
tensor.data<T>(), tensor.data<T>(), temp_max.data<T>(), tensor.numel()
);
CUDA_CHECK(cudaGetLastError());
// Compute exponentials
Kernels::vec_exp<<<gridSize, BLOCK_SIZE>>>(
tensor.data<float>(), tensor.data<float>(), tensor.numel()
tensor.data<T>(), tensor.data<T>(), tensor.numel()
);
CUDA_CHECK(cudaGetLastError());
@@ -47,8 +92,7 @@ void CUDA::softmax(Tensor& tensor, Tensor& temp_max, Tensor& temp_sum) {
sum(tensor, temp_sum);
Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
tensor.data<float>(), tensor.data<float>(), temp_sum.data<float>(),
tensor.numel()
tensor.data<T>(), tensor.data<T>(), temp_sum.data<T>(), tensor.numel()
);
CUDA_CHECK(cudaGetLastError());
CUDA_CHECK(cudaDeviceSynchronize());
@@ -61,20 +105,50 @@ CUDANet::Tensor& CUDA::dense(
CUDANet::Tensor& output,
const size_t input_size,
const size_t output_size
) {
switch (input.get_dtype()) {
case DType::FLOAT32:
return dense_impl<float>(
weights, biases, input, output, input_size, output_size
);
break;
default:
throw std::runtime_error("Unsupported dtype");
break;
}
}
template CUDANet::Tensor& CUDA::dense_impl<float>(
const CUDANet::Tensor& weights,
const CUDANet::Tensor& biases,
const CUDANet::Tensor& input,
CUDANet::Tensor& output,
const size_t input_size,
const size_t output_size
);
template <typename T>
CUDANet::Tensor& CUDA::dense_impl(
const CUDANet::Tensor& weights,
const CUDANet::Tensor& biases,
const CUDANet::Tensor& input,
CUDANet::Tensor& output,
const size_t input_size,
const size_t output_size
) {
auto forwardGridSize =
(std::max(input_size, output_size) + BLOCK_SIZE - 1) / BLOCK_SIZE;
auto biasGridSize = (output_size + BLOCK_SIZE - 1) / BLOCK_SIZE;
Kernels::mat_vec_mul<<<forwardGridSize, BLOCK_SIZE>>>(
weights.data<float>(), input.data<float>(), output.data<float>(),
input_size, output_size
weights.data<T>(), input.data<T>(), output.data<T>(), input_size,
output_size
);
CUDA_CHECK(cudaGetLastError());
Kernels::vec_vec_add<<<biasGridSize, BLOCK_SIZE>>>(
biases.data<float>(), output.data<float>(), output.data<float>(),
output_size
biases.data<T>(), output.data<T>(), output.data<T>(), output_size
);
CUDA_CHECK(cudaGetLastError());
CUDA_CHECK(cudaDeviceSynchronize());
@@ -92,6 +166,44 @@ CUDANet::Tensor& CUDA::conv2d(
const CUDANet::Shape kernel_shape,
const CUDANet::Shape stride_shape,
const CUDANet::Shape out_shape
) {
switch (input.get_dtype()) {
case DType::FLOAT32:
return conv2d_impl<float>(
weights, biases, input, output, in_shape, padding_shape,
kernel_shape, stride_shape, out_shape
);
break;
default:
throw std::runtime_error("Unsupported dtype");
break;
}
}
template CUDANet::Tensor& CUDA::conv2d_impl<float>(
const CUDANet::Tensor& weights,
const CUDANet::Tensor& biases,
const CUDANet::Tensor& input,
CUDANet::Tensor& output,
const CUDANet::Shape in_shape,
const CUDANet::Shape padding_shape,
const CUDANet::Shape kernel_shape,
const CUDANet::Shape stride_shape,
const CUDANet::Shape out_shape
);
template <typename T>
CUDANet::Tensor& CUDA::conv2d_impl(
const CUDANet::Tensor& weights,
const CUDANet::Tensor& biases,
const CUDANet::Tensor& input,
CUDANet::Tensor& output,
const CUDANet::Shape in_shape,
const CUDANet::Shape padding_shape,
const CUDANet::Shape kernel_shape,
const CUDANet::Shape stride_shape,
const CUDANet::Shape out_shape
) {
dim3 block(8, 8, 8);
dim3 grid(
@@ -101,9 +213,8 @@ CUDANet::Tensor& CUDA::conv2d(
);
Kernels::convolution<<<grid, block>>>(
input.data<float>(), weights.data<float>(), biases.data<float>(),
output.data<float>(), in_shape, padding_shape, kernel_shape,
stride_shape, out_shape
input.data<T>(), weights.data<T>(), biases.data<T>(), output.data<T>(),
in_shape, padding_shape, kernel_shape, stride_shape, out_shape
);
CUDA_CHECK(cudaGetLastError());
CUDA_CHECK(cudaDeviceSynchronize());
@@ -113,12 +224,46 @@ CUDANet::Tensor& CUDA::conv2d(
CUDANet::Tensor& CUDA::max_pool2d(
const CUDANet::Tensor& input,
CUDANet::Tensor& output,
CUDANet::Shape input_shape,
CUDANet::Shape pool_shape,
CUDANet::Shape stride_shape,
CUDANet::Shape padding_shape,
CUDANet::Shape output_shape
CUDANet::Tensor& output,
CUDANet::Shape input_shape,
CUDANet::Shape pool_shape,
CUDANet::Shape stride_shape,
CUDANet::Shape padding_shape,
CUDANet::Shape output_shape
) {
switch (input.get_dtype()) {
case DType::FLOAT32:
return max_pool2d_impl<float>(
input, output, input_shape, pool_shape, stride_shape,
padding_shape, output_shape
);
break;
default:
throw std::runtime_error("Unsupported dtype");
break;
}
}
template CUDANet::Tensor& CUDA::max_pool2d_impl<float>(
const CUDANet::Tensor& input,
CUDANet::Tensor& output,
CUDANet::Shape input_shape,
CUDANet::Shape pool_shape,
CUDANet::Shape stride_shape,
CUDANet::Shape padding_shape,
CUDANet::Shape output_shape
);
template <typename T>
CUDANet::Tensor& CUDA::max_pool2d_impl(
const CUDANet::Tensor& input,
CUDANet::Tensor& output,
CUDANet::Shape input_shape,
CUDANet::Shape pool_shape,
CUDANet::Shape stride_shape,
CUDANet::Shape padding_shape,
CUDANet::Shape output_shape
) {
dim3 block(8, 8, 8);
dim3 grid(
@@ -128,8 +273,8 @@ CUDANet::Tensor& CUDA::max_pool2d(
);
Kernels::max_pool<<<grid, block>>>(
input.data<float>(), output.data<float>(), input_shape, output_shape, pool_shape,
stride_shape, padding_shape
input.data<T>(), output.data<T>(), input_shape, output_shape,
pool_shape, stride_shape, padding_shape
);
CUDA_CHECK(cudaGetLastError());
CUDA_CHECK(cudaDeviceSynchronize());
@@ -139,12 +284,46 @@ CUDANet::Tensor& CUDA::max_pool2d(
CUDANet::Tensor& CUDA::avg_pool2d(
const CUDANet::Tensor& input,
CUDANet::Tensor& output,
CUDANet::Shape input_shape,
CUDANet::Shape pool_shape,
CUDANet::Shape stride_shape,
CUDANet::Shape padding_shape,
CUDANet::Shape output_shape
CUDANet::Tensor& output,
CUDANet::Shape input_shape,
CUDANet::Shape pool_shape,
CUDANet::Shape stride_shape,
CUDANet::Shape padding_shape,
CUDANet::Shape output_shape
) {
switch (input.get_dtype()) {
case DType::FLOAT32:
return avg_pool2d_impl<float>(
input, output, input_shape, pool_shape, stride_shape,
padding_shape, output_shape
);
break;
default:
throw std::runtime_error("Unsupported dtype");
break;
}
}
template CUDANet::Tensor& CUDA::avg_pool2d_impl<float>(
const CUDANet::Tensor& input,
CUDANet::Tensor& output,
CUDANet::Shape input_shape,
CUDANet::Shape pool_shape,
CUDANet::Shape stride_shape,
CUDANet::Shape padding_shape,
CUDANet::Shape output_shape
);
template <typename T>
CUDANet::Tensor& CUDA::avg_pool2d_impl(
const CUDANet::Tensor& input,
CUDANet::Tensor& output,
CUDANet::Shape input_shape,
CUDANet::Shape pool_shape,
CUDANet::Shape stride_shape,
CUDANet::Shape padding_shape,
CUDANet::Shape output_shape
) {
dim3 block(8, 8, 8);
dim3 grid(
@@ -154,8 +333,8 @@ CUDANet::Tensor& CUDA::avg_pool2d(
);
Kernels::avg_pool<<<grid, block>>>(
input.data<float>(), output.data<float>(), input_shape, output_shape, pool_shape,
stride_shape, padding_shape
input.data<T>(), output.data<T>(), input_shape, output_shape,
pool_shape, stride_shape, padding_shape
);
CUDA_CHECK(cudaGetLastError());
CUDA_CHECK(cudaDeviceSynchronize());
@@ -165,48 +344,84 @@ CUDANet::Tensor& CUDA::avg_pool2d(
CUDANet::Tensor& CUDA::batch_norm(
const CUDANet::Tensor& input,
CUDANet::Tensor& output,
CUDANet::Shape input_shape,
CUDANet::Tensor& weights,
CUDANet::Tensor& biases,
CUDANet::Tensor& running_mean,
CUDANet::Tensor& running_var,
CUDANet::Tensor& epsilon
CUDANet::Tensor& output,
CUDANet::Shape input_shape,
CUDANet::Tensor& weights,
CUDANet::Tensor& biases,
CUDANet::Tensor& running_mean,
CUDANet::Tensor& running_var,
CUDANet::Tensor& epsilon
) {
switch (input.get_dtype()) {
case DType::FLOAT32:
return batch_norm_impl<float>(
input, output, input_shape, weights, biases, running_mean,
running_var, epsilon
);
break;
default:
throw std::runtime_error("Unsupported dtype");
break;
}
}
template CUDANet::Tensor& CUDA::batch_norm_impl<float>(
const CUDANet::Tensor& input,
CUDANet::Tensor& output,
CUDANet::Shape input_shape,
CUDANet::Tensor& weights,
CUDANet::Tensor& biases,
CUDANet::Tensor& running_mean,
CUDANet::Tensor& running_var,
CUDANet::Tensor& epsilon
);
template <typename T>
CUDANet::Tensor& CUDA::batch_norm_impl(
const CUDANet::Tensor& input,
CUDANet::Tensor& output,
CUDANet::Shape input_shape,
CUDANet::Tensor& weights,
CUDANet::Tensor& biases,
CUDANet::Tensor& running_mean,
CUDANet::Tensor& running_var,
CUDANet::Tensor& epsilon
) {
auto gridSize =
(input_shape[0] * input_shape[1] + BLOCK_SIZE - 1) / BLOCK_SIZE;
for (int i = 0; i < input_shape[2]; i++) {
// Subtract mean from input
Kernels::vec_scalar_sub<<<gridSize, BLOCK_SIZE>>>(
input.data<float>() + i * input_shape[0] * input_shape[1],
output.data<float>() + i * input_shape[0] * input_shape[1],
&running_mean.data<float>()[i], input_shape[0] * input_shape[1]
input.data<T>() + i * input_shape[0] * input_shape[1],
output.data<T>() + i * input_shape[0] * input_shape[1],
&running_mean.data<T>()[i], input_shape[0] * input_shape[1]
);
CUDA_CHECK(cudaGetLastError());
// Divide by sqrt(running_var + epsilon)
Kernels::vec_scale<<<gridSize, BLOCK_SIZE>>>(
output.data<float>() + i * input_shape[0] * input_shape[1],
output.data<float>() + i * input_shape[0] * input_shape[1],
&running_var.data<float>()[i], epsilon.data<float>(), input_shape[0] * input_shape[1]
output.data<T>() + i * input_shape[0] * input_shape[1],
output.data<T>() + i * input_shape[0] * input_shape[1],
&running_var.data<T>()[i], epsilon.data<T>(),
input_shape[0] * input_shape[1]
);
CUDA_CHECK(cudaGetLastError());
// Multiply by weights
Kernels::vec_scalar_mul<<<gridSize, BLOCK_SIZE>>>(
output.data<float>() + i * input_shape[0] * input_shape[1],
output.data<float>() + i * input_shape[0] * input_shape[1], &weights.data<float>()[i],
input_shape[0] * input_shape[1]
output.data<T>() + i * input_shape[0] * input_shape[1],
output.data<T>() + i * input_shape[0] * input_shape[1],
&weights.data<T>()[i], input_shape[0] * input_shape[1]
);
CUDA_CHECK(cudaGetLastError());
// Add biases
Kernels::vec_scalar_add<<<gridSize, BLOCK_SIZE>>>(
output.data<float>() + i * input_shape[0] * input_shape[1],
output.data<float>() + i * input_shape[0] * input_shape[1], &biases.data<float>()[i],
input_shape[0] * input_shape[1]
output.data<T>() + i * input_shape[0] * input_shape[1],
output.data<T>() + i * input_shape[0] * input_shape[1],
&biases.data<T>()[i], input_shape[0] * input_shape[1]
);
CUDA_CHECK(cudaGetLastError());
}
@@ -218,14 +433,39 @@ CUDANet::Tensor& CUDA::concat(
CUDANet::Tensor& input_a,
CUDANet::Tensor& input_b,
CUDANet::Tensor& output
) {
switch (input_a.get_dtype()) {
case DType::FLOAT32:
return concat_impl<float>(
input_a, input_b, output
);
break;
default:
throw std::runtime_error("Unsupported dtype");
break;
}
}
template CUDANet::Tensor& CUDA::concat_impl<float>(
CUDANet::Tensor& input_a,
CUDANet::Tensor& input_b,
CUDANet::Tensor& output
);
template <typename T>
CUDANet::Tensor& CUDA::concat_impl(
CUDANet::Tensor& input_a,
CUDANet::Tensor& input_b,
CUDANet::Tensor& output
) {
CUDA_CHECK(cudaMemcpy(
output.data<float>(), input_a.data<float>(), input_a.size(),
output.data<T>(), input_a.data<T>(), input_a.size(),
cudaMemcpyDeviceToDevice
));
CUDA_CHECK(cudaMemcpy(
output.data<float>() + input_a.numel(), input_b.data<float>(), input_b.size(),
output.data<T>() + input_a.numel(), input_b.data<T>(), input_b.size(),
cudaMemcpyDeviceToDevice
));
@@ -239,11 +479,36 @@ CUDANet::Tensor& CUDA::add(
CUDANet::Tensor& input_a,
CUDANet::Tensor& input_b,
CUDANet::Tensor& output
) {
switch (input_a.get_dtype()) {
case DType::FLOAT32:
return add_impl<float>(
input_a, input_b, output
);
break;
default:
throw std::runtime_error("Unsupported dtype");
break;
}
}
template CUDANet::Tensor& CUDA::add_impl<float>(
CUDANet::Tensor& input_a,
CUDANet::Tensor& input_b,
CUDANet::Tensor& output
);
template <typename T>
CUDANet::Tensor& CUDA::add_impl(
CUDANet::Tensor& input_a,
CUDANet::Tensor& input_b,
CUDANet::Tensor& output
) {
auto gridSize = (input_a.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE;
Kernels::vec_vec_add<<<gridSize, BLOCK_SIZE>>>(
input_a.data<float>(), input_b.data<float>(), output.data<float>(), input_a.numel()
input_a.data<T>(), input_b.data<T>(), output.data<T>(), input_a.numel()
);
CUDA_CHECK(cudaGetLastError());
CUDA_CHECK(cudaDeviceSynchronize());

View File

@@ -7,11 +7,26 @@
using namespace CUDANet::Backends;
void CUDA::print(const CUDANet::Tensor &input) {
switch (input.get_dtype()) {
case DType::FLOAT32:
print_impl<float>(input);
break;
default:
throw std::runtime_error("Unsupported dtype");
break;
}
}
template void CUDA::print_impl<float> (const CUDANet::Tensor &input);
template <typename T>
void CUDA::print_impl(const CUDANet::Tensor &input) {
auto length = input.numel();
std::vector<float> h_vec(input.numel());
std::vector<T> h_vec(input.numel());
CUDA_CHECK(cudaMemcpy(
h_vec.data(), input.data<float>(), sizeof(float) * length, cudaMemcpyDeviceToHost
h_vec.data(), input.data<T>(), sizeof(T) * length, cudaMemcpyDeviceToHost
));
for (int i = 0; i < length; ++i) {
@@ -26,27 +41,71 @@ void CUDA::zero(CUDANet::Tensor &input) {
}
void CUDA::fill(CUDANet::Tensor &input, int value) {
CUDA_CHECK(cudaMemset(input.data<float>(), value, sizeof(float) * input.numel()));
switch (input.get_dtype()) {
case DType::FLOAT32:
fill_impl<float>(input, value);
break;
default:
throw std::runtime_error("Unsupported dtype");
break;
}
}
template void CUDA::fill_impl<float>(CUDANet::Tensor &input, int value);
template <typename T>
void CUDA::fill_impl(CUDANet::Tensor &input, int value) {
CUDA_CHECK(cudaMemset(input.data<T>(), value, sizeof(T) * input.numel()));
}
void CUDA::copy_to_device(CUDANet::Tensor &tensor, void *data, size_t size) {
CUDA_CHECK(cudaMemcpy(tensor.data<float>(), data, size, cudaMemcpyHostToDevice));
switch (tensor.get_dtype()) {
case DType::FLOAT32:
copy_to_device_impl<float>(tensor, data, size);
break;
default:
throw std::runtime_error("Unsupported dtype");
break;
}
}
template void CUDA::copy_to_device_impl<float>(CUDANet::Tensor &tensor, void *data, size_t size);
template <typename T>
void CUDA::copy_to_device_impl(CUDANet::Tensor &tensor, void *data, size_t size) {
CUDA_CHECK(cudaMemcpy(tensor.data<T>(), data, size, cudaMemcpyHostToDevice));
}
void CUDA::sum(const CUDANet::Tensor &input, CUDANet::Tensor &sum) {
switch (input.get_dtype()) {
case DType::FLOAT32:
sum_impl<float>(input, sum);
break;
default:
throw std::runtime_error("Unsupported dtype");
break;
}
}
template void CUDA::sum_impl<float>(const CUDANet::Tensor &input, CUDANet::Tensor &sum);
template <typename T>
void CUDA::sum_impl(const CUDANet::Tensor &input, CUDANet::Tensor &sum) {
auto length = input.numel();
const int gridSize = ( + BLOCK_SIZE - 1) / BLOCK_SIZE;
CUDANet::Kernels::sum_reduce<<<gridSize, BLOCK_SIZE>>>(
input.data<float>(), sum.data<float>(), length
input.data<T>(), sum.data<T>(), length
);
CUDA_CHECK(cudaGetLastError());
int remaining = gridSize;
while (remaining > 1) {
int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
CUDANet::Kernels::sum_reduce<<<blocks_needed, BLOCK_SIZE>>>(sum.data<float>(), sum.data<float>(), remaining);
CUDANet::Kernels::sum_reduce<<<blocks_needed, BLOCK_SIZE>>>(sum.data<T>(), sum.data<T>(), remaining);
CUDA_CHECK(cudaGetLastError());
remaining = blocks_needed;
@@ -54,17 +113,32 @@ void CUDA::sum(const CUDANet::Tensor &input, CUDANet::Tensor &sum) {
}
void CUDA::max(const CUDANet::Tensor &input, CUDANet::Tensor &max) {
switch (input.get_dtype()) {
case DType::FLOAT32:
max_impl<float>(input, max);
break;
default:
throw std::runtime_error("Unsupported dtype");
break;
}
}
template void CUDA::max_impl<float>(const CUDANet::Tensor &input, CUDANet::Tensor &max);
template <typename T>
void CUDA::max_impl(const CUDANet::Tensor &input, CUDANet::Tensor &max) {
auto length = input.numel();
const int grid_size = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
Kernels::max_reduce<<<grid_size, BLOCK_SIZE>>>(input.data<float>(), max.data<float>(), length);
Kernels::max_reduce<<<grid_size, BLOCK_SIZE>>>(input.data<T>(), max.data<T>(), length);
CUDA_CHECK(cudaGetLastError());
int remaining = grid_size;
while (remaining > 1) {
int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
CUDANet::Kernels::max_reduce<<<blocks_needed, BLOCK_SIZE>>>(max.data<float>(), max.data<float>(), remaining);
CUDANet::Kernels::max_reduce<<<blocks_needed, BLOCK_SIZE>>>(max.data<T>(), max.data<T>(), remaining);
CUDA_CHECK(cudaGetLastError());
remaining = blocks_needed;

View File

@@ -56,6 +56,7 @@ size_t Dense::output_size() {
return out_shape[0];
};
// TODO: Use dtype
void Dense::set_weights(void* input) {
weights.set_data<float>(static_cast<float*>(input));
}

View File

@@ -80,6 +80,10 @@ Tensor::~Tensor() {
}
}
DType Tensor::get_dtype() const {
return dtype;
}
size_t Tensor::numel() const {
return total_elms;
}