mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-12-22 14:24:22 +00:00
Compare commits
4 Commits
13d3d38b68
...
7e27c87673
| Author | SHA1 | Date | |
|---|---|---|---|
| 7e27c87673 | |||
| e79667671a | |||
| c855ae89ec | |||
| 9ff214d759 |
@@ -43,6 +43,11 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
|||||||
add_library(${PROJECT_NAME} STATIC ${LIBRARY_SOURCES})
|
add_library(${PROJECT_NAME} STATIC ${LIBRARY_SOURCES})
|
||||||
|
|
||||||
if(USE_CUDA)
|
if(USE_CUDA)
|
||||||
|
# Enable relocatable device code for proper template instantiation across translation units
|
||||||
|
set_target_properties(${PROJECT_NAME} PROPERTIES
|
||||||
|
CUDA_SEPARABLE_COMPILATION ON
|
||||||
|
CUDA_RUNTIME_LIBRARY Shared
|
||||||
|
)
|
||||||
target_link_libraries(${PROJECT_NAME} CUDA::cudart)
|
target_link_libraries(${PROJECT_NAME} CUDA::cudart)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|||||||
@@ -8,9 +8,10 @@
|
|||||||
|
|
||||||
namespace CUDANet {
|
namespace CUDANet {
|
||||||
|
|
||||||
// Forward declaration
|
// Forward declarations
|
||||||
class Tensor;
|
|
||||||
class Backend;
|
class Backend;
|
||||||
|
class Tensor;
|
||||||
|
enum class DType;
|
||||||
|
|
||||||
enum BackendType { CUDA_BACKEND, CPU_BACKEND };
|
enum BackendType { CUDA_BACKEND, CPU_BACKEND };
|
||||||
|
|
||||||
@@ -28,6 +29,7 @@ class Backend {
|
|||||||
std::optional<DType> default_dtype;
|
std::optional<DType> default_dtype;
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
// Dtypes
|
||||||
virtual bool supports_dtype(DType dtype) const = 0;
|
virtual bool supports_dtype(DType dtype) const = 0;
|
||||||
virtual void set_default_dtype(DType dtype) = 0;
|
virtual void set_default_dtype(DType dtype) = 0;
|
||||||
virtual DType get_default_dtype() const = 0;
|
virtual DType get_default_dtype() const = 0;
|
||||||
|
|||||||
@@ -8,53 +8,60 @@
|
|||||||
|
|
||||||
#ifndef BLOCK_SIZE
|
#ifndef BLOCK_SIZE
|
||||||
#define BLOCK_SIZE 128
|
#define BLOCK_SIZE 128
|
||||||
#endif // BLOCK_SIZE
|
#endif // BLOCK_SIZE
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief CUDA error checking macro
|
* @brief CUDA error checking macro
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
#define CUDA_CHECK(call) \
|
#define CUDA_CHECK(call) \
|
||||||
do { \
|
do { \
|
||||||
cudaError_t result = call; \
|
cudaError_t result = call; \
|
||||||
if (result != cudaSuccess) { \
|
if (result != cudaSuccess) { \
|
||||||
fprintf(stderr, "CUDA error at %s:%d code=%d(%s) \"%s\" \n", \
|
fprintf( \
|
||||||
__FILE__, __LINE__, static_cast<unsigned int>(result), \
|
stderr, "CUDA error at %s:%d code=%d(%s) \"%s\" \n", __FILE__, \
|
||||||
cudaGetErrorString(result), #call); \
|
__LINE__, static_cast<unsigned int>(result), \
|
||||||
exit(EXIT_FAILURE); \
|
cudaGetErrorString(result), #call \
|
||||||
} \
|
); \
|
||||||
} while (0)
|
exit(EXIT_FAILURE); \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
namespace CUDANet::Backends {
|
namespace CUDANet::Backends {
|
||||||
|
|
||||||
|
template <DType dtype>
|
||||||
|
struct cuda_dtype_map;
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct cuda_dtype_map<DType::FLOAT32> {
|
||||||
|
using type = float;
|
||||||
|
};
|
||||||
|
|
||||||
class CUDA : public Backend {
|
class CUDA : public Backend {
|
||||||
private:
|
|
||||||
int device_id;
|
|
||||||
std::set<DType> supported_dtypes;
|
|
||||||
public:
|
public:
|
||||||
CUDA(const BackendConfig& config);
|
CUDA(const BackendConfig& config);
|
||||||
|
|
||||||
bool supports_dtype(DType dtype) const override;
|
bool supports_dtype(DType dtype) const override;
|
||||||
void set_default_dtype(DType dtype) override;
|
void set_default_dtype(DType dtype) override;
|
||||||
DType get_default_dtype() const override;
|
DType get_default_dtype() const override;
|
||||||
|
|
||||||
static bool is_cuda_available();
|
static bool is_cuda_available();
|
||||||
void initialize();
|
void initialize();
|
||||||
|
|
||||||
// Memory management
|
// Memory management
|
||||||
void* allocate(size_t bytes) override;
|
void* allocate(size_t bytes) override;
|
||||||
void deallocate(void* ptr) override;
|
void deallocate(void* ptr) override;
|
||||||
|
|
||||||
// Tensor ops
|
// Tensor ops dispatchers
|
||||||
void print(const CUDANet::Tensor& input) override;
|
void print(const CUDANet::Tensor& input) override;
|
||||||
void zero(CUDANet::Tensor& input) override;
|
void zero(CUDANet::Tensor& input) override;
|
||||||
void fill(CUDANet::Tensor &input, int value) override;
|
void fill(CUDANet::Tensor& input, int value) override;
|
||||||
void
|
void
|
||||||
copy_to_device(CUDANet::Tensor& tensor, void* data, size_t size) override;
|
copy_to_device(CUDANet::Tensor& tensor, void* data, size_t size) override;
|
||||||
void sum(const CUDANet::Tensor& input, CUDANet::Tensor& sum) override;
|
void sum(const CUDANet::Tensor& input, CUDANet::Tensor& sum) override;
|
||||||
void max(const CUDANet::Tensor& input, CUDANet::Tensor& max) override;
|
void max(const CUDANet::Tensor& input, CUDANet::Tensor& max) override;
|
||||||
|
|
||||||
// Layer ops
|
// Layer ops dispatchers
|
||||||
void relu(CUDANet::Tensor& tensor) override;
|
void relu(CUDANet::Tensor& tensor) override;
|
||||||
void sigmoid(CUDANet::Tensor& tensor) override;
|
void sigmoid(CUDANet::Tensor& tensor) override;
|
||||||
void softmax(
|
void softmax(
|
||||||
@@ -67,7 +74,7 @@ class CUDA : public Backend {
|
|||||||
const CUDANet::Tensor& weights,
|
const CUDANet::Tensor& weights,
|
||||||
const CUDANet::Tensor& biases,
|
const CUDANet::Tensor& biases,
|
||||||
const CUDANet::Tensor& input,
|
const CUDANet::Tensor& input,
|
||||||
CUDANet::Tensor& output,
|
CUDANet::Tensor& output,
|
||||||
const size_t input_size,
|
const size_t input_size,
|
||||||
const size_t output_size
|
const size_t output_size
|
||||||
) override;
|
) override;
|
||||||
@@ -76,43 +83,43 @@ class CUDA : public Backend {
|
|||||||
const CUDANet::Tensor& weights,
|
const CUDANet::Tensor& weights,
|
||||||
const CUDANet::Tensor& biases,
|
const CUDANet::Tensor& biases,
|
||||||
const CUDANet::Tensor& input,
|
const CUDANet::Tensor& input,
|
||||||
CUDANet::Tensor& output,
|
CUDANet::Tensor& output,
|
||||||
const CUDANet::Shape in_shape,
|
const CUDANet::Shape in_shape,
|
||||||
const CUDANet::Shape padding_shape,
|
const CUDANet::Shape padding_shape,
|
||||||
const CUDANet::Shape kernel_shape,
|
const CUDANet::Shape kernel_shape,
|
||||||
const CUDANet::Shape stride_shape,
|
const CUDANet::Shape stride_shape,
|
||||||
const CUDANet::Shape out_shape
|
const CUDANet::Shape out_shape
|
||||||
) override;
|
) override;
|
||||||
|
|
||||||
CUDANet::Tensor& max_pool2d(
|
CUDANet::Tensor& max_pool2d(
|
||||||
const CUDANet::Tensor& input,
|
const CUDANet::Tensor& input,
|
||||||
CUDANet::Tensor& output,
|
CUDANet::Tensor& output,
|
||||||
CUDANet::Shape input_shape,
|
CUDANet::Shape input_shape,
|
||||||
CUDANet::Shape pool_shape,
|
CUDANet::Shape pool_shape,
|
||||||
CUDANet::Shape stride_shape,
|
CUDANet::Shape stride_shape,
|
||||||
CUDANet::Shape padding_shape,
|
CUDANet::Shape padding_shape,
|
||||||
CUDANet::Shape output_shape
|
CUDANet::Shape output_shape
|
||||||
) override;
|
) override;
|
||||||
|
|
||||||
CUDANet::Tensor& avg_pool2d(
|
CUDANet::Tensor& avg_pool2d(
|
||||||
const CUDANet::Tensor& input,
|
const CUDANet::Tensor& input,
|
||||||
CUDANet::Tensor& output,
|
CUDANet::Tensor& output,
|
||||||
CUDANet::Shape input_shape,
|
CUDANet::Shape input_shape,
|
||||||
CUDANet::Shape pool_shape,
|
CUDANet::Shape pool_shape,
|
||||||
CUDANet::Shape stride_shape,
|
CUDANet::Shape stride_shape,
|
||||||
CUDANet::Shape padding_shape,
|
CUDANet::Shape padding_shape,
|
||||||
CUDANet::Shape output_shape
|
CUDANet::Shape output_shape
|
||||||
) override;
|
) override;
|
||||||
|
|
||||||
CUDANet::Tensor& batch_norm(
|
CUDANet::Tensor& batch_norm(
|
||||||
const CUDANet::Tensor& input,
|
const CUDANet::Tensor& input,
|
||||||
CUDANet::Tensor& output,
|
CUDANet::Tensor& output,
|
||||||
CUDANet::Shape input_shape,
|
CUDANet::Shape input_shape,
|
||||||
CUDANet::Tensor& weights,
|
CUDANet::Tensor& weights,
|
||||||
CUDANet::Tensor& biases,
|
CUDANet::Tensor& biases,
|
||||||
CUDANet::Tensor& running_mean,
|
CUDANet::Tensor& running_mean,
|
||||||
CUDANet::Tensor& running_var,
|
CUDANet::Tensor& running_var,
|
||||||
CUDANet::Tensor& epsilon
|
CUDANet::Tensor& epsilon
|
||||||
) override;
|
) override;
|
||||||
|
|
||||||
CUDANet::Tensor& concat(
|
CUDANet::Tensor& concat(
|
||||||
@@ -126,6 +133,111 @@ class CUDA : public Backend {
|
|||||||
CUDANet::Tensor& input_b,
|
CUDANet::Tensor& input_b,
|
||||||
CUDANet::Tensor& output
|
CUDANet::Tensor& output
|
||||||
) override;
|
) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
int device_id;
|
||||||
|
std::set<DType> supported_dtypes;
|
||||||
|
|
||||||
|
// Tensor ops template impls
|
||||||
|
template <typename T>
|
||||||
|
void print_impl(const CUDANet::Tensor& input);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void fill_impl(CUDANet::Tensor& input, int value);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void copy_to_device_impl(CUDANet::Tensor& tensor, void* data, size_t size);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void sum_impl(const CUDANet::Tensor& input, CUDANet::Tensor& sum);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void max_impl(const CUDANet::Tensor& input, CUDANet::Tensor& max);
|
||||||
|
|
||||||
|
// Layer ops template impls
|
||||||
|
template <typename T>
|
||||||
|
void relu_impl(CUDANet::Tensor& tensor);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void sigmoid_impl(CUDANet::Tensor& tensor);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void softmax_impl(
|
||||||
|
CUDANet::Tensor& tensor,
|
||||||
|
CUDANet::Tensor& temp_max,
|
||||||
|
CUDANet::Tensor& temp_sum
|
||||||
|
);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
CUDANet::Tensor& dense_impl(
|
||||||
|
const CUDANet::Tensor& weights,
|
||||||
|
const CUDANet::Tensor& biases,
|
||||||
|
const CUDANet::Tensor& input,
|
||||||
|
CUDANet::Tensor& output,
|
||||||
|
const size_t input_size,
|
||||||
|
const size_t output_size
|
||||||
|
);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
CUDANet::Tensor& conv2d_impl(
|
||||||
|
const CUDANet::Tensor& weights,
|
||||||
|
const CUDANet::Tensor& biases,
|
||||||
|
const CUDANet::Tensor& input,
|
||||||
|
CUDANet::Tensor& output,
|
||||||
|
const CUDANet::Shape in_shape,
|
||||||
|
const CUDANet::Shape padding_shape,
|
||||||
|
const CUDANet::Shape kernel_shape,
|
||||||
|
const CUDANet::Shape stride_shape,
|
||||||
|
const CUDANet::Shape out_shape
|
||||||
|
);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
CUDANet::Tensor& max_pool2d_impl(
|
||||||
|
const CUDANet::Tensor& input,
|
||||||
|
CUDANet::Tensor& output,
|
||||||
|
CUDANet::Shape input_shape,
|
||||||
|
CUDANet::Shape pool_shape,
|
||||||
|
CUDANet::Shape stride_shape,
|
||||||
|
CUDANet::Shape padding_shape,
|
||||||
|
CUDANet::Shape output_shape
|
||||||
|
);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
CUDANet::Tensor& avg_pool2d_impl(
|
||||||
|
const CUDANet::Tensor& input,
|
||||||
|
CUDANet::Tensor& output,
|
||||||
|
CUDANet::Shape input_shape,
|
||||||
|
CUDANet::Shape pool_shape,
|
||||||
|
CUDANet::Shape stride_shape,
|
||||||
|
CUDANet::Shape padding_shape,
|
||||||
|
CUDANet::Shape output_shape
|
||||||
|
);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
CUDANet::Tensor& batch_norm_impl(
|
||||||
|
const CUDANet::Tensor& input,
|
||||||
|
CUDANet::Tensor& output,
|
||||||
|
CUDANet::Shape input_shape,
|
||||||
|
CUDANet::Tensor& weights,
|
||||||
|
CUDANet::Tensor& biases,
|
||||||
|
CUDANet::Tensor& running_mean,
|
||||||
|
CUDANet::Tensor& running_var,
|
||||||
|
CUDANet::Tensor& epsilon
|
||||||
|
);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
CUDANet::Tensor& concat_impl(
|
||||||
|
CUDANet::Tensor& input_a,
|
||||||
|
CUDANet::Tensor& input_b,
|
||||||
|
CUDANet::Tensor& output
|
||||||
|
);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
CUDANet::Tensor& add_impl(
|
||||||
|
CUDANet::Tensor& input_a,
|
||||||
|
CUDANet::Tensor& input_b,
|
||||||
|
CUDANet::Tensor& output
|
||||||
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace CUDANet::Backend
|
} // namespace CUDANet::Backends
|
||||||
@@ -4,29 +4,18 @@
|
|||||||
|
|
||||||
namespace CUDANet::Kernels {
|
namespace CUDANet::Kernels {
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Sigmoid activation function kernel
|
template <typename T>
|
||||||
*
|
|
||||||
* @param src Pointer to the source array
|
|
||||||
* @param dst Pointer to the destination array
|
|
||||||
* @param len Length of the arrays
|
|
||||||
*/
|
|
||||||
__global__ void sigmoid(
|
__global__ void sigmoid(
|
||||||
const float* __restrict__ src,
|
const T* __restrict__ src,
|
||||||
float* __restrict__ dst,
|
T* __restrict__ dst,
|
||||||
const unsigned int len
|
const unsigned int len
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
template <typename T>
|
||||||
* @brief Relu activation function kernel
|
|
||||||
*
|
|
||||||
* @param src Pointer to the source array
|
|
||||||
* @param dst Pointer to the destination array
|
|
||||||
* @param len Length of the arrays
|
|
||||||
*/
|
|
||||||
__global__ void relu(
|
__global__ void relu(
|
||||||
const float* __restrict__ src,
|
const T* __restrict__ src,
|
||||||
float* __restrict__ dst,
|
T* __restrict__ dst,
|
||||||
const unsigned int len
|
const unsigned int len
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|||||||
@@ -5,11 +5,12 @@
|
|||||||
|
|
||||||
namespace CUDANet::Kernels {
|
namespace CUDANet::Kernels {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
__global__ void convolution(
|
__global__ void convolution(
|
||||||
const float* __restrict__ d_input,
|
const T* __restrict__ d_input,
|
||||||
const float* __restrict__ d_kernel,
|
const T* __restrict__ d_kernel,
|
||||||
const float* __restrict__ d_bias,
|
const T* __restrict__ d_bias,
|
||||||
float* __restrict__ d_output,
|
T* __restrict__ d_output,
|
||||||
const Shape input_shape,
|
const Shape input_shape,
|
||||||
const Shape padding_shape,
|
const Shape padding_shape,
|
||||||
const Shape kernel_shape,
|
const Shape kernel_shape,
|
||||||
|
|||||||
@@ -4,188 +4,105 @@
|
|||||||
|
|
||||||
namespace CUDANet::Kernels {
|
namespace CUDANet::Kernels {
|
||||||
|
|
||||||
/**
|
template <typename T>
|
||||||
* @brief Matrix vector multiplication kernel
|
|
||||||
*
|
|
||||||
* @param d_matrix Device pointer to matrix
|
|
||||||
* @param d_vector Device pointer to vector
|
|
||||||
* @param d_output Device pointer to output vector
|
|
||||||
* @param w Width of the matrix
|
|
||||||
* @param h Height of the matrix
|
|
||||||
*/
|
|
||||||
__global__ void mat_vec_mul(
|
__global__ void mat_vec_mul(
|
||||||
const float* __restrict__ d_matrix,
|
const T* __restrict__ d_matrix,
|
||||||
const float* __restrict__ d_vector,
|
const T* __restrict__ d_vector,
|
||||||
float* __restrict__ d_output,
|
T* __restrict__ d_output,
|
||||||
const unsigned int w,
|
const unsigned int w,
|
||||||
const unsigned int h
|
const unsigned int h
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
template <typename T>
|
||||||
* @brief Vector vector addition kernel
|
|
||||||
*
|
|
||||||
* @param d_vector1 Device pointer to first vector
|
|
||||||
* @param d_vector2 Device pointer to second vector
|
|
||||||
* @param d_output Device pointer to output vector
|
|
||||||
* @param w Length of the vectors
|
|
||||||
*/
|
|
||||||
__global__ void vec_vec_add(
|
__global__ void vec_vec_add(
|
||||||
const float* __restrict__ d_vector1,
|
const T* __restrict__ d_vector1,
|
||||||
const float* __restrict__ d_vector2,
|
const T* __restrict__ d_vector2,
|
||||||
float* __restrict__ d_output,
|
T* __restrict__ d_output,
|
||||||
const unsigned int w
|
const unsigned int w
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
template <typename T>
|
||||||
* @brief Vector vector subtraction kernel
|
|
||||||
*
|
|
||||||
* @param d_vector1
|
|
||||||
* @param d_vector2
|
|
||||||
* @param d_output
|
|
||||||
* @param w
|
|
||||||
* @return __global__
|
|
||||||
*/
|
|
||||||
__global__ void vec_vec_sub(
|
__global__ void vec_vec_sub(
|
||||||
const float* __restrict__ d_vector1,
|
const T* __restrict__ d_vector1,
|
||||||
const float* __restrict__ d_vector2,
|
const T* __restrict__ d_vector2,
|
||||||
float* __restrict__ d_output,
|
T* __restrict__ d_output,
|
||||||
const unsigned int w
|
const unsigned int w
|
||||||
);
|
);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
__global__ void vec_vec_mul(
|
__global__ void vec_vec_mul(
|
||||||
const float* __restrict__ d_vector1,
|
const T* __restrict__ d_vector1,
|
||||||
const float* __restrict__ d_vector2,
|
const T* __restrict__ d_vector2,
|
||||||
float* __restrict__ d_output,
|
T* __restrict__ d_output,
|
||||||
const unsigned int w
|
const unsigned int w
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
template <typename T>
|
||||||
* @brief Sub scalar from each element of the vector
|
|
||||||
*
|
|
||||||
* @param d_vector
|
|
||||||
* @param d_scalar
|
|
||||||
* @param d_output
|
|
||||||
* @param w
|
|
||||||
* @return __global__
|
|
||||||
*/
|
|
||||||
__global__ void vec_scalar_sub(
|
__global__ void vec_scalar_sub(
|
||||||
const float* __restrict__ d_src,
|
const T* __restrict__ d_src,
|
||||||
float* __restrict__ d_out,
|
T* __restrict__ d_out,
|
||||||
const float* __restrict__ d_scalar,
|
const T* __restrict__ d_scalar,
|
||||||
const unsigned int len
|
const unsigned int len
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
template <typename T>
|
||||||
* @brief Add scalar to each element of the vector
|
|
||||||
*
|
|
||||||
* @param d_src
|
|
||||||
* @param d_out
|
|
||||||
* @param d_scalar
|
|
||||||
* @param len
|
|
||||||
* @return __global__
|
|
||||||
*/
|
|
||||||
__global__ void vec_scalar_add(
|
__global__ void vec_scalar_add(
|
||||||
const float* __restrict__ d_src,
|
const T* __restrict__ d_src,
|
||||||
float* __restrict__ d_out,
|
T* __restrict__ d_out,
|
||||||
const float* __restrict__ d_scalar,
|
const T* __restrict__ d_scalar,
|
||||||
const unsigned int len
|
const unsigned int len
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
template <typename T>
|
||||||
* @brief Divide each element of the vector by a scalar
|
|
||||||
*
|
|
||||||
* @param src Pointer to the source array
|
|
||||||
* @param dst Pointer to the destination array
|
|
||||||
* @param len Length of the arrays
|
|
||||||
*/
|
|
||||||
__global__ void vec_scalar_div(
|
__global__ void vec_scalar_div(
|
||||||
const float* __restrict__ d_src,
|
const T* __restrict__ d_src,
|
||||||
float* __restrict__ d_out,
|
T* __restrict__ d_out,
|
||||||
const float* __restrict__ d_scalar,
|
const T* __restrict__ d_scalar,
|
||||||
const unsigned int len
|
const unsigned int len
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
template <typename T>
|
||||||
* @brief Multiply each element of the vector by a scalar
|
|
||||||
*
|
|
||||||
* @param d_src
|
|
||||||
* @param d_out
|
|
||||||
* @param d_scalar
|
|
||||||
* @param len
|
|
||||||
* @return __global__
|
|
||||||
*/
|
|
||||||
__global__ void vec_scalar_mul(
|
__global__ void vec_scalar_mul(
|
||||||
const float* __restrict__ d_src,
|
const T* __restrict__ d_src,
|
||||||
float* __restrict__ d_out,
|
T* __restrict__ d_out,
|
||||||
const float* __restrict__ d_scalar,
|
const T* __restrict__ d_scalar,
|
||||||
const unsigned int len
|
const unsigned int len
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
template <typename T>
|
||||||
* @brief Exponentiate each element of the vector
|
|
||||||
*
|
|
||||||
* @param src Pointer to the source array
|
|
||||||
* @param dst Pointer to the destination array
|
|
||||||
* @param len Length of the arrays
|
|
||||||
*/
|
|
||||||
__global__ void vec_exp(
|
__global__ void vec_exp(
|
||||||
const float* __restrict__ src,
|
const T* __restrict__ src,
|
||||||
float* __restrict__ dst,
|
T* __restrict__ dst,
|
||||||
const unsigned int len
|
const unsigned int len
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
template <typename T>
|
||||||
* @brief Compute the square root of each element of the vector
|
|
||||||
*
|
|
||||||
* @param src Device pointer to source vector
|
|
||||||
* @param dst Device pointer to destination vector
|
|
||||||
* @param len Length of the vector
|
|
||||||
*/
|
|
||||||
__global__ void vec_sqrt(
|
__global__ void vec_sqrt(
|
||||||
const float* __restrict__ src,
|
const T* __restrict__ src,
|
||||||
float* __restrict__ dst,
|
T* __restrict__ dst,
|
||||||
const unsigned int len
|
const unsigned int len
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
template <typename T>
|
||||||
* @brief Scales the vector by 1/sqrt(scale + epsilon)
|
|
||||||
*
|
|
||||||
* @param src Device pointer to source vector
|
|
||||||
* @param dst Device pointer to destination vector
|
|
||||||
* @param scale Scale
|
|
||||||
* @param epsilon Epsilon
|
|
||||||
* @param len Length of the vector
|
|
||||||
*/
|
|
||||||
__global__ void vec_scale(
|
__global__ void vec_scale(
|
||||||
const float* __restrict__ src,
|
const T* __restrict__ src,
|
||||||
float* __restrict__ dst,
|
T* __restrict__ dst,
|
||||||
const float* __restrict__ scale,
|
const T* __restrict__ scale,
|
||||||
const float* epsilon,
|
const T* epsilon,
|
||||||
const unsigned int len
|
const unsigned int len
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
template <typename T>
|
||||||
* @brief Max reduction kernel
|
|
||||||
*
|
|
||||||
* @param d_vector Device pointer to vector
|
|
||||||
* @param d_output Device pointer to output vector
|
|
||||||
*/
|
|
||||||
__global__ void max_reduce(
|
__global__ void max_reduce(
|
||||||
const float* __restrict__ d_vector,
|
const T* __restrict__ d_vector,
|
||||||
float* __restrict__ d_output,
|
T* __restrict__ d_output,
|
||||||
const unsigned int len
|
const unsigned int len
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
template <typename T>
|
||||||
* @brief
|
|
||||||
*
|
|
||||||
* @param d_vector Device pointer to vector
|
|
||||||
* @param d_output Device pointer to output vector
|
|
||||||
* @param len Length of the vector
|
|
||||||
*/
|
|
||||||
__global__ void sum_reduce(
|
__global__ void sum_reduce(
|
||||||
const float* __restrict__ d_vector,
|
const T* __restrict__ d_vector,
|
||||||
float* __restrict__ d_output,
|
T* __restrict__ d_output,
|
||||||
const unsigned int len
|
const unsigned int len
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|||||||
@@ -5,9 +5,10 @@
|
|||||||
|
|
||||||
namespace CUDANet::Kernels {
|
namespace CUDANet::Kernels {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
__global__ void max_pool(
|
__global__ void max_pool(
|
||||||
const float* __restrict__ d_input,
|
const T* __restrict__ d_input,
|
||||||
float* __restrict__ d_output,
|
T* __restrict__ d_output,
|
||||||
const Shape input_shape,
|
const Shape input_shape,
|
||||||
const Shape output_shape,
|
const Shape output_shape,
|
||||||
const Shape pool_shape,
|
const Shape pool_shape,
|
||||||
@@ -15,9 +16,10 @@ __global__ void max_pool(
|
|||||||
const Shape padding_shape
|
const Shape padding_shape
|
||||||
);
|
);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
__global__ void avg_pool(
|
__global__ void avg_pool(
|
||||||
const float* __restrict__ d_input,
|
const T* __restrict__ d_input,
|
||||||
float* __restrict__ d_output,
|
T* __restrict__ d_output,
|
||||||
const Shape input_shape,
|
const Shape input_shape,
|
||||||
const Shape output_shape,
|
const Shape output_shape,
|
||||||
const Shape pool_shape,
|
const Shape pool_shape,
|
||||||
|
|||||||
@@ -15,10 +15,6 @@ class Module {
|
|||||||
|
|
||||||
CUDANet::Shape output_shape();
|
CUDANet::Shape output_shape();
|
||||||
|
|
||||||
size_t input_size();
|
|
||||||
|
|
||||||
size_t output_size();
|
|
||||||
|
|
||||||
void register_layer(const std::string& name, Layer* layer);
|
void register_layer(const std::string& name, Layer* layer);
|
||||||
|
|
||||||
void register_module(Module& module);
|
void register_module(Module& module);
|
||||||
|
|||||||
@@ -16,7 +16,20 @@ enum class DType
|
|||||||
// INT32, // Not implemented yet
|
// INT32, // Not implemented yet
|
||||||
};
|
};
|
||||||
|
|
||||||
size_t dtype_size(DType dtype);
|
size_t dtype_size(DType dtype) {
|
||||||
|
switch (dtype)
|
||||||
|
{
|
||||||
|
case DType::FLOAT32:
|
||||||
|
return 4;
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("Unknown DType");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class Backend;
|
||||||
|
|
||||||
class Tensor
|
class Tensor
|
||||||
{
|
{
|
||||||
@@ -33,32 +46,19 @@ public:
|
|||||||
|
|
||||||
~Tensor();
|
~Tensor();
|
||||||
|
|
||||||
DType get_dtype();
|
DType get_dtype() const;
|
||||||
|
|
||||||
size_t size() const;
|
size_t size() const;
|
||||||
size_t numel() const;
|
size_t numel() const;
|
||||||
|
|
||||||
template <typename T>
|
void* device_ptr() const;
|
||||||
const T* data() const {
|
void* device_ptr();
|
||||||
return static_cast<T*>(d_ptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
T* data() {
|
|
||||||
return static_cast<T*>(d_ptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
void zero();
|
void zero();
|
||||||
|
|
||||||
template <typename T>
|
void fill(int value);
|
||||||
void fill(T value) {
|
|
||||||
backend->fill(*this, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
void set_data(void *data);
|
||||||
void set_data(T *data) {
|
|
||||||
backend->copy_to_device(*this, data, total_size);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Shape shape;
|
Shape shape;
|
||||||
|
|||||||
@@ -13,24 +13,24 @@ std::unique_ptr<Backend> BackendFactory::create(BackendType backend_type, const
|
|||||||
switch (backend_type)
|
switch (backend_type)
|
||||||
{
|
{
|
||||||
case BackendType::CUDA_BACKEND:
|
case BackendType::CUDA_BACKEND:
|
||||||
|
{
|
||||||
#ifdef USE_CUDA
|
#ifdef USE_CUDA
|
||||||
|
|
||||||
if (!CUDANet::Backends::CUDA::is_cuda_available()) {
|
if (!CUDANet::Backends::CUDA::is_cuda_available()) {
|
||||||
throw std::runtime_error("No CUDA devices found")
|
throw std::runtime_error("No CUDA devices found");
|
||||||
}
|
}
|
||||||
|
|
||||||
auto cuda = std::make_unique<CUDANet::Backends::CUDA>(config);
|
auto cuda = std::make_unique<CUDANet::Backends::CUDA>(config);
|
||||||
cuda.initialize();
|
|
||||||
|
|
||||||
return cuda;
|
return cuda;
|
||||||
|
|
||||||
#else
|
#else
|
||||||
throw std::runtime_error("Library was compiled without CUDA support.");
|
throw std::runtime_error("Library was compiled without CUDA support.");
|
||||||
#endif
|
#endif
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
default:
|
default:
|
||||||
|
throw std::runtime_error("Invalid backend");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -65,7 +65,6 @@ CUDANet::DType CUDA::get_default_dtype() const {
|
|||||||
return DType::FLOAT32;
|
return DType::FLOAT32;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void* CUDA::allocate(size_t bytes) {
|
void* CUDA::allocate(size_t bytes) {
|
||||||
void* d_ptr = nullptr;
|
void* d_ptr = nullptr;
|
||||||
CUDA_CHECK(cudaMalloc(&d_ptr, bytes));
|
CUDA_CHECK(cudaMalloc(&d_ptr, bytes));
|
||||||
|
|||||||
@@ -2,10 +2,18 @@
|
|||||||
|
|
||||||
using namespace CUDANet;
|
using namespace CUDANet;
|
||||||
|
|
||||||
__global__ void Kernels::sigmoid(
|
template
|
||||||
|
__global__ void Kernels::sigmoid<float>(
|
||||||
const float* __restrict__ src,
|
const float* __restrict__ src,
|
||||||
float* __restrict__ dst,
|
float* __restrict__ dst,
|
||||||
const unsigned int len
|
const unsigned int len
|
||||||
|
);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void Kernels::sigmoid(
|
||||||
|
const T* __restrict__ src,
|
||||||
|
T* __restrict__ dst,
|
||||||
|
const unsigned int len
|
||||||
) {
|
) {
|
||||||
int stride = gridDim.x * blockDim.x;
|
int stride = gridDim.x * blockDim.x;
|
||||||
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
@@ -15,10 +23,17 @@ __global__ void Kernels::sigmoid(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void Kernels::relu(
|
template __global__ void Kernels::relu<float>(
|
||||||
const float* __restrict__ src,
|
const float* __restrict__ src,
|
||||||
float* __restrict__ dst,
|
float* __restrict__ dst,
|
||||||
const unsigned int len
|
const unsigned int len
|
||||||
|
);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void Kernels::relu(
|
||||||
|
const T* __restrict__ src,
|
||||||
|
T* __restrict__ dst,
|
||||||
|
const unsigned int len
|
||||||
) {
|
) {
|
||||||
int stride = gridDim.x * blockDim.x;
|
int stride = gridDim.x * blockDim.x;
|
||||||
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
using namespace CUDANet;
|
using namespace CUDANet;
|
||||||
|
|
||||||
__global__ void Kernels::convolution(
|
template __global__ void Kernels::convolution<float>(
|
||||||
const float* __restrict__ d_input,
|
const float* __restrict__ d_input,
|
||||||
const float* __restrict__ d_kernel,
|
const float* __restrict__ d_kernel,
|
||||||
const float* __restrict__ d_bias,
|
const float* __restrict__ d_bias,
|
||||||
@@ -14,6 +14,19 @@ __global__ void Kernels::convolution(
|
|||||||
const Shape kernel_shape,
|
const Shape kernel_shape,
|
||||||
const Shape stride_shape,
|
const Shape stride_shape,
|
||||||
const Shape output_shape
|
const Shape output_shape
|
||||||
|
);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void Kernels::convolution(
|
||||||
|
const T* __restrict__ d_input,
|
||||||
|
const T* __restrict__ d_kernel,
|
||||||
|
const T* __restrict__ d_bias,
|
||||||
|
T* __restrict__ d_output,
|
||||||
|
const Shape input_shape,
|
||||||
|
const Shape padding_shape,
|
||||||
|
const Shape kernel_shape,
|
||||||
|
const Shape stride_shape,
|
||||||
|
const Shape output_shape
|
||||||
) {
|
) {
|
||||||
int j = blockDim.x * blockIdx.x + threadIdx.x;
|
int j = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
int i = blockDim.y * blockIdx.y + threadIdx.y;
|
int i = blockDim.y * blockIdx.y + threadIdx.y;
|
||||||
@@ -23,7 +36,7 @@ __global__ void Kernels::convolution(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
float sum = 0.0f;
|
T sum = static_cast<T>(0);
|
||||||
|
|
||||||
// Iterate over kernel and input matrix
|
// Iterate over kernel and input matrix
|
||||||
for (int c = 0; c < input_shape[2]; c++) {
|
for (int c = 0; c < input_shape[2]; c++) {
|
||||||
|
|||||||
@@ -3,17 +3,26 @@
|
|||||||
|
|
||||||
using namespace CUDANet;
|
using namespace CUDANet;
|
||||||
|
|
||||||
__global__ void Kernels::mat_vec_mul(
|
template __global__ void Kernels::mat_vec_mul<float>(
|
||||||
const float* __restrict__ d_matrix,
|
const float* __restrict__ d_matrix,
|
||||||
const float* __restrict__ d_vector,
|
const float* __restrict__ d_vector,
|
||||||
float* __restrict__ d_output,
|
float* __restrict__ d_output,
|
||||||
const unsigned int w,
|
const unsigned int w,
|
||||||
const unsigned int h
|
const unsigned int h
|
||||||
|
);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void Kernels::mat_vec_mul(
|
||||||
|
const T* __restrict__ d_matrix,
|
||||||
|
const T* __restrict__ d_vector,
|
||||||
|
T* __restrict__ d_output,
|
||||||
|
const unsigned int w,
|
||||||
|
const unsigned int h
|
||||||
) {
|
) {
|
||||||
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
|
|
||||||
if (tid < h) {
|
if (tid < h) {
|
||||||
float temp = 0.0f;
|
T temp = static_cast<T>(0);
|
||||||
|
|
||||||
for (unsigned int j = 0; j < w; j++) {
|
for (unsigned int j = 0; j < w; j++) {
|
||||||
temp += d_matrix[tid * w + j] * d_vector[j];
|
temp += d_matrix[tid * w + j] * d_vector[j];
|
||||||
@@ -23,11 +32,19 @@ __global__ void Kernels::mat_vec_mul(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void Kernels::vec_vec_add(
|
template __global__ void Kernels::vec_vec_add<float>(
|
||||||
const float* __restrict__ d_vector1,
|
const float* __restrict__ d_vector1,
|
||||||
const float* __restrict__ d_vector2,
|
const float* __restrict__ d_vector2,
|
||||||
float* __restrict__ d_output,
|
float* __restrict__ d_output,
|
||||||
const unsigned int w
|
const unsigned int w
|
||||||
|
);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void Kernels::vec_vec_add(
|
||||||
|
const T* __restrict__ d_vector1,
|
||||||
|
const T* __restrict__ d_vector2,
|
||||||
|
T* __restrict__ d_output,
|
||||||
|
const unsigned int w
|
||||||
) {
|
) {
|
||||||
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
if (tid >= w) {
|
if (tid >= w) {
|
||||||
@@ -36,11 +53,19 @@ __global__ void Kernels::vec_vec_add(
|
|||||||
d_output[tid] = d_vector1[tid] + d_vector2[tid];
|
d_output[tid] = d_vector1[tid] + d_vector2[tid];
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void Kernels::vec_vec_sub(
|
template __global__ void Kernels::vec_vec_sub<float>(
|
||||||
const float* __restrict__ d_vector1,
|
const float* __restrict__ d_vector1,
|
||||||
const float* __restrict__ d_vector2,
|
const float* __restrict__ d_vector2,
|
||||||
float* __restrict__ d_output,
|
float* __restrict__ d_output,
|
||||||
const unsigned int w
|
const unsigned int w
|
||||||
|
);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void Kernels::vec_vec_sub(
|
||||||
|
const T* __restrict__ d_vector1,
|
||||||
|
const T* __restrict__ d_vector2,
|
||||||
|
T* __restrict__ d_output,
|
||||||
|
const unsigned int w
|
||||||
) {
|
) {
|
||||||
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
if (tid >= w) {
|
if (tid >= w) {
|
||||||
@@ -49,11 +74,19 @@ __global__ void Kernels::vec_vec_sub(
|
|||||||
d_output[tid] = d_vector1[tid] - d_vector2[tid];
|
d_output[tid] = d_vector1[tid] - d_vector2[tid];
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void Kernels::vec_vec_mul(
|
template __global__ void Kernels::vec_vec_mul<float>(
|
||||||
const float* __restrict__ d_vector1,
|
const float* __restrict__ d_vector1,
|
||||||
const float* __restrict__ d_vector2,
|
const float* __restrict__ d_vector2,
|
||||||
float* __restrict__ d_output,
|
float* __restrict__ d_output,
|
||||||
const unsigned int w
|
const unsigned int w
|
||||||
|
);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void Kernels::vec_vec_mul(
|
||||||
|
const T* __restrict__ d_vector1,
|
||||||
|
const T* __restrict__ d_vector2,
|
||||||
|
T* __restrict__ d_output,
|
||||||
|
const unsigned int w
|
||||||
) {
|
) {
|
||||||
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
if (tid >= w) {
|
if (tid >= w) {
|
||||||
@@ -62,11 +95,19 @@ __global__ void Kernels::vec_vec_mul(
|
|||||||
d_output[tid] = d_vector1[tid] * d_vector2[tid];
|
d_output[tid] = d_vector1[tid] * d_vector2[tid];
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void Kernels::vec_scalar_sub(
|
template __global__ void Kernels::vec_scalar_sub<float>(
|
||||||
const float* __restrict__ d_src,
|
const float* __restrict__ d_src,
|
||||||
float* __restrict__ d_out,
|
float* __restrict__ d_out,
|
||||||
const float* __restrict__ d_scalar,
|
const float* __restrict__ d_scalar,
|
||||||
const unsigned int len
|
const unsigned int len
|
||||||
|
);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void Kernels::vec_scalar_sub(
|
||||||
|
const T* __restrict__ d_src,
|
||||||
|
T* __restrict__ d_out,
|
||||||
|
const T* __restrict__ d_scalar,
|
||||||
|
const unsigned int len
|
||||||
) {
|
) {
|
||||||
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
if (tid >= len) {
|
if (tid >= len) {
|
||||||
@@ -75,11 +116,19 @@ __global__ void Kernels::vec_scalar_sub(
|
|||||||
d_out[tid] = d_src[tid] - *d_scalar;
|
d_out[tid] = d_src[tid] - *d_scalar;
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void Kernels::vec_scalar_add(
|
template __global__ void Kernels::vec_scalar_add<float>(
|
||||||
const float* __restrict__ d_src,
|
const float* __restrict__ d_src,
|
||||||
float* __restrict__ d_out,
|
float* __restrict__ d_out,
|
||||||
const float* __restrict__ d_scalar,
|
const float* __restrict__ d_scalar,
|
||||||
const unsigned int len
|
const unsigned int len
|
||||||
|
);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void Kernels::vec_scalar_add(
|
||||||
|
const T* __restrict__ d_src,
|
||||||
|
T* __restrict__ d_out,
|
||||||
|
const T* __restrict__ d_scalar,
|
||||||
|
const unsigned int len
|
||||||
) {
|
) {
|
||||||
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
if (tid >= len) {
|
if (tid >= len) {
|
||||||
@@ -88,11 +137,19 @@ __global__ void Kernels::vec_scalar_add(
|
|||||||
d_out[tid] = d_src[tid] + *d_scalar;
|
d_out[tid] = d_src[tid] + *d_scalar;
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void Kernels::vec_scalar_div(
|
template __global__ void Kernels::vec_scalar_div<float>(
|
||||||
const float* __restrict__ d_src,
|
const float* __restrict__ d_src,
|
||||||
float* __restrict__ d_out,
|
float* __restrict__ d_out,
|
||||||
const float* __restrict__ d_scalar,
|
const float* __restrict__ d_scalar,
|
||||||
const unsigned int len
|
const unsigned int len
|
||||||
|
);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void Kernels::vec_scalar_div(
|
||||||
|
const T* __restrict__ d_src,
|
||||||
|
T* __restrict__ d_out,
|
||||||
|
const T* __restrict__ d_scalar,
|
||||||
|
const unsigned int len
|
||||||
) {
|
) {
|
||||||
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
if (tid >= len) {
|
if (tid >= len) {
|
||||||
@@ -101,11 +158,19 @@ __global__ void Kernels::vec_scalar_div(
|
|||||||
d_out[tid] = d_src[tid] / *d_scalar;
|
d_out[tid] = d_src[tid] / *d_scalar;
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void Kernels::vec_scalar_mul(
|
template __global__ void Kernels::vec_scalar_mul<float>(
|
||||||
const float* __restrict__ d_src,
|
const float* __restrict__ d_src,
|
||||||
float* __restrict__ d_out,
|
float* __restrict__ d_out,
|
||||||
const float* __restrict__ d_scalar,
|
const float* __restrict__ d_scalar,
|
||||||
const unsigned int len
|
const unsigned int len
|
||||||
|
);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void Kernels::vec_scalar_mul(
|
||||||
|
const T* __restrict__ d_src,
|
||||||
|
T* __restrict__ d_out,
|
||||||
|
const T* __restrict__ d_scalar,
|
||||||
|
const unsigned int len
|
||||||
) {
|
) {
|
||||||
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
if (tid >= len) {
|
if (tid >= len) {
|
||||||
@@ -114,52 +179,85 @@ __global__ void Kernels::vec_scalar_mul(
|
|||||||
d_out[tid] = d_src[tid] * *d_scalar;
|
d_out[tid] = d_src[tid] * *d_scalar;
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void Kernels::vec_exp(
|
template __global__ void Kernels::vec_exp<float>(
|
||||||
const float* __restrict__ src,
|
const float* __restrict__ src,
|
||||||
float* __restrict__ dst,
|
float* __restrict__ dst,
|
||||||
const unsigned int len
|
const unsigned int len
|
||||||
|
);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void Kernels::vec_exp(
|
||||||
|
const T* __restrict__ src,
|
||||||
|
T* __restrict__ dst,
|
||||||
|
const unsigned int len
|
||||||
) {
|
) {
|
||||||
int stride = gridDim.x * blockDim.x;
|
int stride = gridDim.x * blockDim.x;
|
||||||
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
|
|
||||||
for (int i = tid; i < len; i += stride) {
|
for (int i = tid; i < len; i += stride) {
|
||||||
|
// TODO: separate implementation for __half
|
||||||
dst[i] = expf(src[i]);
|
dst[i] = expf(src[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void Kernels::vec_sqrt(
|
template __global__ void Kernels::vec_sqrt<float>(
|
||||||
const float* __restrict__ src,
|
const float* __restrict__ src,
|
||||||
float* __restrict__ dst,
|
float* __restrict__ dst,
|
||||||
const unsigned int len
|
const unsigned int len
|
||||||
|
);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void Kernels::vec_sqrt(
|
||||||
|
const T* __restrict__ src,
|
||||||
|
T* __restrict__ dst,
|
||||||
|
const unsigned int len
|
||||||
) {
|
) {
|
||||||
int stride = gridDim.x * blockDim.x;
|
int stride = gridDim.x * blockDim.x;
|
||||||
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
|
|
||||||
for (int i = tid; i < len; i += stride) {
|
for (int i = tid; i < len; i += stride) {
|
||||||
|
// TODO: separate implementation for __half
|
||||||
dst[i] = sqrtf(src[i]);
|
dst[i] = sqrtf(src[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void Kernels::vec_scale(
|
template __global__ void Kernels::vec_scale<float>(
|
||||||
const float* __restrict__ src,
|
const float* __restrict__ src,
|
||||||
float* __restrict__ dst,
|
float* __restrict__ dst,
|
||||||
const float* __restrict__ scale,
|
const float* __restrict__ scale,
|
||||||
const float* epsilon,
|
const float* epsilon,
|
||||||
const unsigned int len
|
const unsigned int len
|
||||||
|
);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void Kernels::vec_scale(
|
||||||
|
const T* __restrict__ src,
|
||||||
|
T* __restrict__ dst,
|
||||||
|
const T* __restrict__ scale,
|
||||||
|
const T* epsilon,
|
||||||
|
const unsigned int len
|
||||||
) {
|
) {
|
||||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
if (idx < len) {
|
if (idx < len) {
|
||||||
|
// TODO: separate implementation for __half
|
||||||
float inv_std = rsqrtf(*scale + *epsilon);
|
float inv_std = rsqrtf(*scale + *epsilon);
|
||||||
dst[idx] = src[idx] * inv_std;
|
dst[idx] = src[idx] * inv_std;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void Kernels::max_reduce(
|
template __global__ void Kernels::max_reduce<float>(
|
||||||
const float* __restrict__ d_vector,
|
const float* __restrict__ d_vector,
|
||||||
float* __restrict__ d_output,
|
float* __restrict__ d_output,
|
||||||
const unsigned int len
|
const unsigned int len
|
||||||
|
);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void Kernels::max_reduce(
|
||||||
|
const T* __restrict__ d_vector,
|
||||||
|
T* __restrict__ d_output,
|
||||||
|
const unsigned int len
|
||||||
) {
|
) {
|
||||||
__shared__ float shared_max[BLOCK_SIZE];
|
__shared__ T shared_max[BLOCK_SIZE];
|
||||||
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
if (i < len) {
|
if (i < len) {
|
||||||
@@ -172,6 +270,7 @@ __global__ void Kernels::max_reduce(
|
|||||||
|
|
||||||
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
|
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
|
||||||
if (threadIdx.x < s) {
|
if (threadIdx.x < s) {
|
||||||
|
// TODO: separate implementation for __half
|
||||||
shared_max[threadIdx.x] = fmaxf(shared_max[threadIdx.x], shared_max[threadIdx.x + s]);
|
shared_max[threadIdx.x] = fmaxf(shared_max[threadIdx.x], shared_max[threadIdx.x + s]);
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
@@ -182,18 +281,25 @@ __global__ void Kernels::max_reduce(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void Kernels::sum_reduce(
|
template __global__ void Kernels::sum_reduce<float>(
|
||||||
const float* __restrict__ d_vector,
|
const float* __restrict__ d_vector,
|
||||||
float* __restrict__ d_output,
|
float* __restrict__ d_output,
|
||||||
const unsigned int len
|
const unsigned int len
|
||||||
|
);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void Kernels::sum_reduce(
|
||||||
|
const T* __restrict__ d_vector,
|
||||||
|
T* __restrict__ d_output,
|
||||||
|
const unsigned int len
|
||||||
) {
|
) {
|
||||||
__shared__ float partial_sum[BLOCK_SIZE];
|
__shared__ T partial_sum[BLOCK_SIZE];
|
||||||
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
if (i < len) {
|
if (i < len) {
|
||||||
partial_sum[threadIdx.x] = d_vector[i];
|
partial_sum[threadIdx.x] = d_vector[i];
|
||||||
} else {
|
} else {
|
||||||
partial_sum[threadIdx.x] = 0.0f;
|
partial_sum[threadIdx.x] = static_cast<T>(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
using namespace CUDANet;
|
using namespace CUDANet;
|
||||||
|
|
||||||
__global__ void Kernels::max_pool(
|
template __global__ void Kernels::max_pool<float>(
|
||||||
const float* __restrict__ d_input,
|
const float* __restrict__ d_input,
|
||||||
float* __restrict__ d_output,
|
float* __restrict__ d_output,
|
||||||
const Shape input_shape,
|
const Shape input_shape,
|
||||||
@@ -11,6 +11,17 @@ __global__ void Kernels::max_pool(
|
|||||||
const Shape pool_shape,
|
const Shape pool_shape,
|
||||||
const Shape stride_shape,
|
const Shape stride_shape,
|
||||||
const Shape padding_shape
|
const Shape padding_shape
|
||||||
|
);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void Kernels::max_pool(
|
||||||
|
const T* __restrict__ d_input,
|
||||||
|
T* __restrict__ d_output,
|
||||||
|
const Shape input_shape,
|
||||||
|
const Shape output_shape,
|
||||||
|
const Shape pool_shape,
|
||||||
|
const Shape stride_shape,
|
||||||
|
const Shape padding_shape
|
||||||
) {
|
) {
|
||||||
int j = blockDim.x * blockIdx.x + threadIdx.x;
|
int j = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
int i = blockDim.y * blockIdx.y + threadIdx.y;
|
int i = blockDim.y * blockIdx.y + threadIdx.y;
|
||||||
@@ -20,7 +31,7 @@ __global__ void Kernels::max_pool(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
float max = 0.0f;
|
T max = static_cast<T>(0);
|
||||||
|
|
||||||
for (int k = 0; k < pool_shape[0]; k++) {
|
for (int k = 0; k < pool_shape[0]; k++) {
|
||||||
for (int l = 0; l < pool_shape[1]; l++) {
|
for (int l = 0; l < pool_shape[1]; l++) {
|
||||||
@@ -43,7 +54,7 @@ __global__ void Kernels::max_pool(
|
|||||||
max;
|
max;
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void Kernels::avg_pool(
|
template __global__ void Kernels::avg_pool<float>(
|
||||||
const float* __restrict__ d_input,
|
const float* __restrict__ d_input,
|
||||||
float* __restrict__ d_output,
|
float* __restrict__ d_output,
|
||||||
const Shape input_shape,
|
const Shape input_shape,
|
||||||
@@ -51,6 +62,17 @@ __global__ void Kernels::avg_pool(
|
|||||||
const Shape pool_shape,
|
const Shape pool_shape,
|
||||||
const Shape stride_shape,
|
const Shape stride_shape,
|
||||||
const Shape padding_shape
|
const Shape padding_shape
|
||||||
|
);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void Kernels::avg_pool(
|
||||||
|
const T* __restrict__ d_input,
|
||||||
|
T* __restrict__ d_output,
|
||||||
|
const Shape input_shape,
|
||||||
|
const Shape output_shape,
|
||||||
|
const Shape pool_shape,
|
||||||
|
const Shape stride_shape,
|
||||||
|
const Shape padding_shape
|
||||||
) {
|
) {
|
||||||
int j = blockDim.x * blockIdx.x + threadIdx.x;
|
int j = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
int i = blockDim.y * blockIdx.y + threadIdx.y;
|
int i = blockDim.y * blockIdx.y + threadIdx.y;
|
||||||
@@ -60,7 +82,7 @@ __global__ void Kernels::avg_pool(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
float sum = 0.0f;
|
T sum = static_cast<T>(0);
|
||||||
|
|
||||||
for (int k = 0; k < pool_shape[0]; k++) {
|
for (int k = 0; k < pool_shape[0]; k++) {
|
||||||
for (int l = 0; l < pool_shape[1]; l++) {
|
for (int l = 0; l < pool_shape[1]; l++) {
|
||||||
|
|||||||
@@ -7,24 +7,70 @@
|
|||||||
using namespace CUDANet::Backends;
|
using namespace CUDANet::Backends;
|
||||||
|
|
||||||
void CUDA::relu(Tensor& tensor) {
|
void CUDA::relu(Tensor& tensor) {
|
||||||
|
switch (tensor.get_dtype()) {
|
||||||
|
case DType::FLOAT32:
|
||||||
|
relu_impl<float>(tensor);
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("Unsupported dtype");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template void CUDA::relu_impl<float>(Tensor& tensor);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void CUDA::relu_impl(Tensor& tensor) {
|
||||||
int gridSize = (tensor.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
int gridSize = (tensor.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||||
Kernels::relu<<<gridSize, BLOCK_SIZE>>>(
|
Kernels::relu<<<gridSize, BLOCK_SIZE>>>(
|
||||||
tensor.data<float>(), tensor.data<float>(), tensor.numel()
|
static_cast<T*>(tensor.device_ptr()), static_cast<T*>(tensor.device_ptr()), tensor.numel()
|
||||||
);
|
);
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
CUDA_CHECK(cudaDeviceSynchronize());
|
CUDA_CHECK(cudaDeviceSynchronize());
|
||||||
}
|
}
|
||||||
|
|
||||||
void CUDA::sigmoid(Tensor& tensor) {
|
void CUDA::sigmoid(CUDANet::Tensor& tensor) {
|
||||||
|
switch (tensor.get_dtype()) {
|
||||||
|
case DType::FLOAT32:
|
||||||
|
sigmoid_impl<float>(tensor);
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("Unsupported dtype");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template void CUDA::sigmoid_impl<float>(Tensor& tensor);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void CUDA::sigmoid_impl(CUDANet::Tensor& tensor) {
|
||||||
int gridSize = (tensor.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
int gridSize = (tensor.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||||
Kernels::sigmoid<<<gridSize, BLOCK_SIZE>>>(
|
Kernels::sigmoid<<<gridSize, BLOCK_SIZE>>>(
|
||||||
tensor.data<float>(), tensor.data<float>(), tensor.numel()
|
static_cast<T*>(tensor.device_ptr()), static_cast<T*>(tensor.device_ptr()), tensor.numel()
|
||||||
);
|
);
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
CUDA_CHECK(cudaDeviceSynchronize());
|
CUDA_CHECK(cudaDeviceSynchronize());
|
||||||
}
|
}
|
||||||
|
|
||||||
void CUDA::softmax(Tensor& tensor, Tensor& temp_max, Tensor& temp_sum) {
|
void CUDA::softmax(Tensor& tensor, Tensor& temp_max, Tensor& temp_sum) {
|
||||||
|
switch (tensor.get_dtype()) {
|
||||||
|
case DType::FLOAT32:
|
||||||
|
softmax_impl<float>(tensor, temp_max, temp_sum);
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("Unsupported dtype");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template void
|
||||||
|
CUDA::softmax_impl<float>(Tensor& tensor, Tensor& temp_max, Tensor& temp_sum);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void CUDA::softmax_impl(Tensor& tensor, Tensor& temp_max, Tensor& temp_sum) {
|
||||||
int gridSize = (tensor.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
int gridSize = (tensor.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||||
|
|
||||||
// Find max value
|
// Find max value
|
||||||
@@ -32,14 +78,13 @@ void CUDA::softmax(Tensor& tensor, Tensor& temp_max, Tensor& temp_sum) {
|
|||||||
|
|
||||||
// Subtract max value to improve numerical stability
|
// Subtract max value to improve numerical stability
|
||||||
Kernels::vec_scalar_sub<<<gridSize, BLOCK_SIZE>>>(
|
Kernels::vec_scalar_sub<<<gridSize, BLOCK_SIZE>>>(
|
||||||
tensor.data<float>(), tensor.data<float>(), temp_max.data<float>(),
|
static_cast<T*>(tensor.device_ptr()), static_cast<T*>(tensor.device_ptr()), static_cast<T*>(temp_max.device_ptr()), tensor.numel()
|
||||||
tensor.numel()
|
|
||||||
);
|
);
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
// Compute exponentials
|
// Compute exponentials
|
||||||
Kernels::vec_exp<<<gridSize, BLOCK_SIZE>>>(
|
Kernels::vec_exp<<<gridSize, BLOCK_SIZE>>>(
|
||||||
tensor.data<float>(), tensor.data<float>(), tensor.numel()
|
static_cast<T*>(tensor.device_ptr()), static_cast<T*>(tensor.device_ptr()), tensor.numel()
|
||||||
);
|
);
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
@@ -47,8 +92,7 @@ void CUDA::softmax(Tensor& tensor, Tensor& temp_max, Tensor& temp_sum) {
|
|||||||
sum(tensor, temp_sum);
|
sum(tensor, temp_sum);
|
||||||
|
|
||||||
Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
|
Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
|
||||||
tensor.data<float>(), tensor.data<float>(), temp_sum.data<float>(),
|
static_cast<T*>(tensor.device_ptr()), static_cast<T*>(tensor.device_ptr()), static_cast<T*>(temp_sum.device_ptr()), tensor.numel()
|
||||||
tensor.numel()
|
|
||||||
);
|
);
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
CUDA_CHECK(cudaDeviceSynchronize());
|
CUDA_CHECK(cudaDeviceSynchronize());
|
||||||
@@ -61,20 +105,50 @@ CUDANet::Tensor& CUDA::dense(
|
|||||||
CUDANet::Tensor& output,
|
CUDANet::Tensor& output,
|
||||||
const size_t input_size,
|
const size_t input_size,
|
||||||
const size_t output_size
|
const size_t output_size
|
||||||
|
) {
|
||||||
|
switch (input.get_dtype()) {
|
||||||
|
case DType::FLOAT32:
|
||||||
|
return dense_impl<float>(
|
||||||
|
weights, biases, input, output, input_size, output_size
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("Unsupported dtype");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template CUDANet::Tensor& CUDA::dense_impl<float>(
|
||||||
|
const CUDANet::Tensor& weights,
|
||||||
|
const CUDANet::Tensor& biases,
|
||||||
|
const CUDANet::Tensor& input,
|
||||||
|
CUDANet::Tensor& output,
|
||||||
|
const size_t input_size,
|
||||||
|
const size_t output_size
|
||||||
|
);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
CUDANet::Tensor& CUDA::dense_impl(
|
||||||
|
const CUDANet::Tensor& weights,
|
||||||
|
const CUDANet::Tensor& biases,
|
||||||
|
const CUDANet::Tensor& input,
|
||||||
|
CUDANet::Tensor& output,
|
||||||
|
const size_t input_size,
|
||||||
|
const size_t output_size
|
||||||
) {
|
) {
|
||||||
auto forwardGridSize =
|
auto forwardGridSize =
|
||||||
(std::max(input_size, output_size) + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
(std::max(input_size, output_size) + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||||
auto biasGridSize = (output_size + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
auto biasGridSize = (output_size + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||||
|
|
||||||
Kernels::mat_vec_mul<<<forwardGridSize, BLOCK_SIZE>>>(
|
Kernels::mat_vec_mul<<<forwardGridSize, BLOCK_SIZE>>>(
|
||||||
weights.data<float>(), input.data<float>(), output.data<float>(),
|
static_cast<T*>(weights.device_ptr()), static_cast<T*>(input.device_ptr()), static_cast<T*>(output.device_ptr()), input_size,
|
||||||
input_size, output_size
|
output_size
|
||||||
);
|
);
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
Kernels::vec_vec_add<<<biasGridSize, BLOCK_SIZE>>>(
|
Kernels::vec_vec_add<<<biasGridSize, BLOCK_SIZE>>>(
|
||||||
biases.data<float>(), output.data<float>(), output.data<float>(),
|
static_cast<T*>(biases.device_ptr()), static_cast<T*>(output.device_ptr()), static_cast<T*>(output.device_ptr()), output_size
|
||||||
output_size
|
|
||||||
);
|
);
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
CUDA_CHECK(cudaDeviceSynchronize());
|
CUDA_CHECK(cudaDeviceSynchronize());
|
||||||
@@ -92,6 +166,44 @@ CUDANet::Tensor& CUDA::conv2d(
|
|||||||
const CUDANet::Shape kernel_shape,
|
const CUDANet::Shape kernel_shape,
|
||||||
const CUDANet::Shape stride_shape,
|
const CUDANet::Shape stride_shape,
|
||||||
const CUDANet::Shape out_shape
|
const CUDANet::Shape out_shape
|
||||||
|
) {
|
||||||
|
switch (input.get_dtype()) {
|
||||||
|
case DType::FLOAT32:
|
||||||
|
return conv2d_impl<float>(
|
||||||
|
weights, biases, input, output, in_shape, padding_shape,
|
||||||
|
kernel_shape, stride_shape, out_shape
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("Unsupported dtype");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template CUDANet::Tensor& CUDA::conv2d_impl<float>(
|
||||||
|
const CUDANet::Tensor& weights,
|
||||||
|
const CUDANet::Tensor& biases,
|
||||||
|
const CUDANet::Tensor& input,
|
||||||
|
CUDANet::Tensor& output,
|
||||||
|
const CUDANet::Shape in_shape,
|
||||||
|
const CUDANet::Shape padding_shape,
|
||||||
|
const CUDANet::Shape kernel_shape,
|
||||||
|
const CUDANet::Shape stride_shape,
|
||||||
|
const CUDANet::Shape out_shape
|
||||||
|
);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
CUDANet::Tensor& CUDA::conv2d_impl(
|
||||||
|
const CUDANet::Tensor& weights,
|
||||||
|
const CUDANet::Tensor& biases,
|
||||||
|
const CUDANet::Tensor& input,
|
||||||
|
CUDANet::Tensor& output,
|
||||||
|
const CUDANet::Shape in_shape,
|
||||||
|
const CUDANet::Shape padding_shape,
|
||||||
|
const CUDANet::Shape kernel_shape,
|
||||||
|
const CUDANet::Shape stride_shape,
|
||||||
|
const CUDANet::Shape out_shape
|
||||||
) {
|
) {
|
||||||
dim3 block(8, 8, 8);
|
dim3 block(8, 8, 8);
|
||||||
dim3 grid(
|
dim3 grid(
|
||||||
@@ -101,9 +213,8 @@ CUDANet::Tensor& CUDA::conv2d(
|
|||||||
);
|
);
|
||||||
|
|
||||||
Kernels::convolution<<<grid, block>>>(
|
Kernels::convolution<<<grid, block>>>(
|
||||||
input.data<float>(), weights.data<float>(), biases.data<float>(),
|
static_cast<T*>(input.device_ptr()), static_cast<T*>(weights.device_ptr()), static_cast<T*>(biases.device_ptr()), static_cast<T*>(output.device_ptr()),
|
||||||
output.data<float>(), in_shape, padding_shape, kernel_shape,
|
in_shape, padding_shape, kernel_shape, stride_shape, out_shape
|
||||||
stride_shape, out_shape
|
|
||||||
);
|
);
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
CUDA_CHECK(cudaDeviceSynchronize());
|
CUDA_CHECK(cudaDeviceSynchronize());
|
||||||
@@ -113,12 +224,46 @@ CUDANet::Tensor& CUDA::conv2d(
|
|||||||
|
|
||||||
CUDANet::Tensor& CUDA::max_pool2d(
|
CUDANet::Tensor& CUDA::max_pool2d(
|
||||||
const CUDANet::Tensor& input,
|
const CUDANet::Tensor& input,
|
||||||
CUDANet::Tensor& output,
|
CUDANet::Tensor& output,
|
||||||
CUDANet::Shape input_shape,
|
CUDANet::Shape input_shape,
|
||||||
CUDANet::Shape pool_shape,
|
CUDANet::Shape pool_shape,
|
||||||
CUDANet::Shape stride_shape,
|
CUDANet::Shape stride_shape,
|
||||||
CUDANet::Shape padding_shape,
|
CUDANet::Shape padding_shape,
|
||||||
CUDANet::Shape output_shape
|
CUDANet::Shape output_shape
|
||||||
|
) {
|
||||||
|
switch (input.get_dtype()) {
|
||||||
|
case DType::FLOAT32:
|
||||||
|
return max_pool2d_impl<float>(
|
||||||
|
input, output, input_shape, pool_shape, stride_shape,
|
||||||
|
padding_shape, output_shape
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("Unsupported dtype");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template CUDANet::Tensor& CUDA::max_pool2d_impl<float>(
|
||||||
|
const CUDANet::Tensor& input,
|
||||||
|
CUDANet::Tensor& output,
|
||||||
|
CUDANet::Shape input_shape,
|
||||||
|
CUDANet::Shape pool_shape,
|
||||||
|
CUDANet::Shape stride_shape,
|
||||||
|
CUDANet::Shape padding_shape,
|
||||||
|
CUDANet::Shape output_shape
|
||||||
|
);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
CUDANet::Tensor& CUDA::max_pool2d_impl(
|
||||||
|
const CUDANet::Tensor& input,
|
||||||
|
CUDANet::Tensor& output,
|
||||||
|
CUDANet::Shape input_shape,
|
||||||
|
CUDANet::Shape pool_shape,
|
||||||
|
CUDANet::Shape stride_shape,
|
||||||
|
CUDANet::Shape padding_shape,
|
||||||
|
CUDANet::Shape output_shape
|
||||||
) {
|
) {
|
||||||
dim3 block(8, 8, 8);
|
dim3 block(8, 8, 8);
|
||||||
dim3 grid(
|
dim3 grid(
|
||||||
@@ -128,8 +273,8 @@ CUDANet::Tensor& CUDA::max_pool2d(
|
|||||||
);
|
);
|
||||||
|
|
||||||
Kernels::max_pool<<<grid, block>>>(
|
Kernels::max_pool<<<grid, block>>>(
|
||||||
input.data<float>(), output.data<float>(), input_shape, output_shape, pool_shape,
|
static_cast<T*>(input.device_ptr()), static_cast<T*>(output.device_ptr()), input_shape, output_shape,
|
||||||
stride_shape, padding_shape
|
pool_shape, stride_shape, padding_shape
|
||||||
);
|
);
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
CUDA_CHECK(cudaDeviceSynchronize());
|
CUDA_CHECK(cudaDeviceSynchronize());
|
||||||
@@ -139,12 +284,46 @@ CUDANet::Tensor& CUDA::max_pool2d(
|
|||||||
|
|
||||||
CUDANet::Tensor& CUDA::avg_pool2d(
|
CUDANet::Tensor& CUDA::avg_pool2d(
|
||||||
const CUDANet::Tensor& input,
|
const CUDANet::Tensor& input,
|
||||||
CUDANet::Tensor& output,
|
CUDANet::Tensor& output,
|
||||||
CUDANet::Shape input_shape,
|
CUDANet::Shape input_shape,
|
||||||
CUDANet::Shape pool_shape,
|
CUDANet::Shape pool_shape,
|
||||||
CUDANet::Shape stride_shape,
|
CUDANet::Shape stride_shape,
|
||||||
CUDANet::Shape padding_shape,
|
CUDANet::Shape padding_shape,
|
||||||
CUDANet::Shape output_shape
|
CUDANet::Shape output_shape
|
||||||
|
) {
|
||||||
|
switch (input.get_dtype()) {
|
||||||
|
case DType::FLOAT32:
|
||||||
|
return avg_pool2d_impl<float>(
|
||||||
|
input, output, input_shape, pool_shape, stride_shape,
|
||||||
|
padding_shape, output_shape
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("Unsupported dtype");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template CUDANet::Tensor& CUDA::avg_pool2d_impl<float>(
|
||||||
|
const CUDANet::Tensor& input,
|
||||||
|
CUDANet::Tensor& output,
|
||||||
|
CUDANet::Shape input_shape,
|
||||||
|
CUDANet::Shape pool_shape,
|
||||||
|
CUDANet::Shape stride_shape,
|
||||||
|
CUDANet::Shape padding_shape,
|
||||||
|
CUDANet::Shape output_shape
|
||||||
|
);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
CUDANet::Tensor& CUDA::avg_pool2d_impl(
|
||||||
|
const CUDANet::Tensor& input,
|
||||||
|
CUDANet::Tensor& output,
|
||||||
|
CUDANet::Shape input_shape,
|
||||||
|
CUDANet::Shape pool_shape,
|
||||||
|
CUDANet::Shape stride_shape,
|
||||||
|
CUDANet::Shape padding_shape,
|
||||||
|
CUDANet::Shape output_shape
|
||||||
) {
|
) {
|
||||||
dim3 block(8, 8, 8);
|
dim3 block(8, 8, 8);
|
||||||
dim3 grid(
|
dim3 grid(
|
||||||
@@ -154,8 +333,8 @@ CUDANet::Tensor& CUDA::avg_pool2d(
|
|||||||
);
|
);
|
||||||
|
|
||||||
Kernels::avg_pool<<<grid, block>>>(
|
Kernels::avg_pool<<<grid, block>>>(
|
||||||
input.data<float>(), output.data<float>(), input_shape, output_shape, pool_shape,
|
static_cast<T*>(input.device_ptr()), static_cast<T*>(output.device_ptr()), input_shape, output_shape,
|
||||||
stride_shape, padding_shape
|
pool_shape, stride_shape, padding_shape
|
||||||
);
|
);
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
CUDA_CHECK(cudaDeviceSynchronize());
|
CUDA_CHECK(cudaDeviceSynchronize());
|
||||||
@@ -165,48 +344,84 @@ CUDANet::Tensor& CUDA::avg_pool2d(
|
|||||||
|
|
||||||
CUDANet::Tensor& CUDA::batch_norm(
|
CUDANet::Tensor& CUDA::batch_norm(
|
||||||
const CUDANet::Tensor& input,
|
const CUDANet::Tensor& input,
|
||||||
CUDANet::Tensor& output,
|
CUDANet::Tensor& output,
|
||||||
CUDANet::Shape input_shape,
|
CUDANet::Shape input_shape,
|
||||||
CUDANet::Tensor& weights,
|
CUDANet::Tensor& weights,
|
||||||
CUDANet::Tensor& biases,
|
CUDANet::Tensor& biases,
|
||||||
CUDANet::Tensor& running_mean,
|
CUDANet::Tensor& running_mean,
|
||||||
CUDANet::Tensor& running_var,
|
CUDANet::Tensor& running_var,
|
||||||
CUDANet::Tensor& epsilon
|
CUDANet::Tensor& epsilon
|
||||||
|
) {
|
||||||
|
switch (input.get_dtype()) {
|
||||||
|
case DType::FLOAT32:
|
||||||
|
return batch_norm_impl<float>(
|
||||||
|
input, output, input_shape, weights, biases, running_mean,
|
||||||
|
running_var, epsilon
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("Unsupported dtype");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template CUDANet::Tensor& CUDA::batch_norm_impl<float>(
|
||||||
|
const CUDANet::Tensor& input,
|
||||||
|
CUDANet::Tensor& output,
|
||||||
|
CUDANet::Shape input_shape,
|
||||||
|
CUDANet::Tensor& weights,
|
||||||
|
CUDANet::Tensor& biases,
|
||||||
|
CUDANet::Tensor& running_mean,
|
||||||
|
CUDANet::Tensor& running_var,
|
||||||
|
CUDANet::Tensor& epsilon
|
||||||
|
);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
CUDANet::Tensor& CUDA::batch_norm_impl(
|
||||||
|
const CUDANet::Tensor& input,
|
||||||
|
CUDANet::Tensor& output,
|
||||||
|
CUDANet::Shape input_shape,
|
||||||
|
CUDANet::Tensor& weights,
|
||||||
|
CUDANet::Tensor& biases,
|
||||||
|
CUDANet::Tensor& running_mean,
|
||||||
|
CUDANet::Tensor& running_var,
|
||||||
|
CUDANet::Tensor& epsilon
|
||||||
) {
|
) {
|
||||||
auto gridSize =
|
auto gridSize =
|
||||||
(input_shape[0] * input_shape[1] + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
(input_shape[0] * input_shape[1] + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||||
|
|
||||||
|
|
||||||
for (int i = 0; i < input_shape[2]; i++) {
|
for (int i = 0; i < input_shape[2]; i++) {
|
||||||
// Subtract mean from input
|
// Subtract mean from input
|
||||||
Kernels::vec_scalar_sub<<<gridSize, BLOCK_SIZE>>>(
|
Kernels::vec_scalar_sub<<<gridSize, BLOCK_SIZE>>>(
|
||||||
input.data<float>() + i * input_shape[0] * input_shape[1],
|
static_cast<T*>(input.device_ptr()) + i * input_shape[0] * input_shape[1],
|
||||||
output.data<float>() + i * input_shape[0] * input_shape[1],
|
static_cast<T*>(output.device_ptr()) + i * input_shape[0] * input_shape[1],
|
||||||
&running_mean.data<float>()[i], input_shape[0] * input_shape[1]
|
&static_cast<T*>(running_mean.device_ptr())[i], input_shape[0] * input_shape[1]
|
||||||
);
|
);
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
// Divide by sqrt(running_var + epsilon)
|
// Divide by sqrt(running_var + epsilon)
|
||||||
Kernels::vec_scale<<<gridSize, BLOCK_SIZE>>>(
|
Kernels::vec_scale<<<gridSize, BLOCK_SIZE>>>(
|
||||||
output.data<float>() + i * input_shape[0] * input_shape[1],
|
static_cast<T*>(output.device_ptr()) + i * input_shape[0] * input_shape[1],
|
||||||
output.data<float>() + i * input_shape[0] * input_shape[1],
|
static_cast<T*>(output.device_ptr()) + i * input_shape[0] * input_shape[1],
|
||||||
&running_var.data<float>()[i], epsilon.data<float>(), input_shape[0] * input_shape[1]
|
&static_cast<T*>(running_var.device_ptr())[i], static_cast<T*>(epsilon.device_ptr()),
|
||||||
|
input_shape[0] * input_shape[1]
|
||||||
);
|
);
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
// Multiply by weights
|
// Multiply by weights
|
||||||
Kernels::vec_scalar_mul<<<gridSize, BLOCK_SIZE>>>(
|
Kernels::vec_scalar_mul<<<gridSize, BLOCK_SIZE>>>(
|
||||||
output.data<float>() + i * input_shape[0] * input_shape[1],
|
static_cast<T*>(output.device_ptr()) + i * input_shape[0] * input_shape[1],
|
||||||
output.data<float>() + i * input_shape[0] * input_shape[1], &weights.data<float>()[i],
|
static_cast<T*>(output.device_ptr()) + i * input_shape[0] * input_shape[1],
|
||||||
input_shape[0] * input_shape[1]
|
&static_cast<T*>(weights.device_ptr())[i], input_shape[0] * input_shape[1]
|
||||||
);
|
);
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
// Add biases
|
// Add biases
|
||||||
Kernels::vec_scalar_add<<<gridSize, BLOCK_SIZE>>>(
|
Kernels::vec_scalar_add<<<gridSize, BLOCK_SIZE>>>(
|
||||||
output.data<float>() + i * input_shape[0] * input_shape[1],
|
static_cast<T*>(output.device_ptr()) + i * input_shape[0] * input_shape[1],
|
||||||
output.data<float>() + i * input_shape[0] * input_shape[1], &biases.data<float>()[i],
|
static_cast<T*>(output.device_ptr()) + i * input_shape[0] * input_shape[1],
|
||||||
input_shape[0] * input_shape[1]
|
&static_cast<T*>(biases.device_ptr())[i], input_shape[0] * input_shape[1]
|
||||||
);
|
);
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
}
|
}
|
||||||
@@ -218,14 +433,39 @@ CUDANet::Tensor& CUDA::concat(
|
|||||||
CUDANet::Tensor& input_a,
|
CUDANet::Tensor& input_a,
|
||||||
CUDANet::Tensor& input_b,
|
CUDANet::Tensor& input_b,
|
||||||
CUDANet::Tensor& output
|
CUDANet::Tensor& output
|
||||||
|
) {
|
||||||
|
switch (input_a.get_dtype()) {
|
||||||
|
case DType::FLOAT32:
|
||||||
|
return concat_impl<float>(
|
||||||
|
input_a, input_b, output
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("Unsupported dtype");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template CUDANet::Tensor& CUDA::concat_impl<float>(
|
||||||
|
CUDANet::Tensor& input_a,
|
||||||
|
CUDANet::Tensor& input_b,
|
||||||
|
CUDANet::Tensor& output
|
||||||
|
);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
CUDANet::Tensor& CUDA::concat_impl(
|
||||||
|
CUDANet::Tensor& input_a,
|
||||||
|
CUDANet::Tensor& input_b,
|
||||||
|
CUDANet::Tensor& output
|
||||||
) {
|
) {
|
||||||
CUDA_CHECK(cudaMemcpy(
|
CUDA_CHECK(cudaMemcpy(
|
||||||
output.data<float>(), input_a.data<float>(), input_a.size(),
|
static_cast<T*>(output.device_ptr()), static_cast<T*>(input_a.device_ptr()), input_a.size(),
|
||||||
cudaMemcpyDeviceToDevice
|
cudaMemcpyDeviceToDevice
|
||||||
));
|
));
|
||||||
|
|
||||||
CUDA_CHECK(cudaMemcpy(
|
CUDA_CHECK(cudaMemcpy(
|
||||||
output.data<float>() + input_a.numel(), input_b.data<float>(), input_b.size(),
|
static_cast<T*>(output.device_ptr()) + input_a.numel(), static_cast<T*>(input_b.device_ptr()), input_b.size(),
|
||||||
cudaMemcpyDeviceToDevice
|
cudaMemcpyDeviceToDevice
|
||||||
));
|
));
|
||||||
|
|
||||||
@@ -239,11 +479,36 @@ CUDANet::Tensor& CUDA::add(
|
|||||||
CUDANet::Tensor& input_a,
|
CUDANet::Tensor& input_a,
|
||||||
CUDANet::Tensor& input_b,
|
CUDANet::Tensor& input_b,
|
||||||
CUDANet::Tensor& output
|
CUDANet::Tensor& output
|
||||||
|
) {
|
||||||
|
switch (input_a.get_dtype()) {
|
||||||
|
case DType::FLOAT32:
|
||||||
|
return add_impl<float>(
|
||||||
|
input_a, input_b, output
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("Unsupported dtype");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template CUDANet::Tensor& CUDA::add_impl<float>(
|
||||||
|
CUDANet::Tensor& input_a,
|
||||||
|
CUDANet::Tensor& input_b,
|
||||||
|
CUDANet::Tensor& output
|
||||||
|
);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
CUDANet::Tensor& CUDA::add_impl(
|
||||||
|
CUDANet::Tensor& input_a,
|
||||||
|
CUDANet::Tensor& input_b,
|
||||||
|
CUDANet::Tensor& output
|
||||||
) {
|
) {
|
||||||
auto gridSize = (input_a.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
auto gridSize = (input_a.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||||
|
|
||||||
Kernels::vec_vec_add<<<gridSize, BLOCK_SIZE>>>(
|
Kernels::vec_vec_add<<<gridSize, BLOCK_SIZE>>>(
|
||||||
input_a.data<float>(), input_b.data<float>(), output.data<float>(), input_a.numel()
|
static_cast<T*>(input_a.device_ptr()), static_cast<T*>(input_b.device_ptr()), static_cast<T*>(output.device_ptr()), input_a.numel()
|
||||||
);
|
);
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
CUDA_CHECK(cudaDeviceSynchronize());
|
CUDA_CHECK(cudaDeviceSynchronize());
|
||||||
|
|||||||
@@ -7,11 +7,26 @@
|
|||||||
using namespace CUDANet::Backends;
|
using namespace CUDANet::Backends;
|
||||||
|
|
||||||
void CUDA::print(const CUDANet::Tensor &input) {
|
void CUDA::print(const CUDANet::Tensor &input) {
|
||||||
|
switch (input.get_dtype()) {
|
||||||
|
case DType::FLOAT32:
|
||||||
|
print_impl<float>(input);
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("Unsupported dtype");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template void CUDA::print_impl<float> (const CUDANet::Tensor &input);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void CUDA::print_impl(const CUDANet::Tensor &input) {
|
||||||
auto length = input.numel();
|
auto length = input.numel();
|
||||||
std::vector<float> h_vec(input.numel());
|
std::vector<T> h_vec(input.numel());
|
||||||
|
|
||||||
CUDA_CHECK(cudaMemcpy(
|
CUDA_CHECK(cudaMemcpy(
|
||||||
h_vec.data(), input.data<float>(), sizeof(float) * length, cudaMemcpyDeviceToHost
|
h_vec.data(), static_cast<T*>(input.device_ptr()), sizeof(T) * length, cudaMemcpyDeviceToHost
|
||||||
));
|
));
|
||||||
|
|
||||||
for (int i = 0; i < length; ++i) {
|
for (int i = 0; i < length; ++i) {
|
||||||
@@ -26,27 +41,71 @@ void CUDA::zero(CUDANet::Tensor &input) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void CUDA::fill(CUDANet::Tensor &input, int value) {
|
void CUDA::fill(CUDANet::Tensor &input, int value) {
|
||||||
CUDA_CHECK(cudaMemset(input.data<float>(), value, sizeof(float) * input.numel()));
|
switch (input.get_dtype()) {
|
||||||
|
case DType::FLOAT32:
|
||||||
|
fill_impl<float>(input, value);
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("Unsupported dtype");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template void CUDA::fill_impl<float>(CUDANet::Tensor &input, int value);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void CUDA::fill_impl(CUDANet::Tensor &input, int value) {
|
||||||
|
CUDA_CHECK(cudaMemset(static_cast<T*>(input.device_ptr()), value, sizeof(T) * input.numel()));
|
||||||
}
|
}
|
||||||
|
|
||||||
void CUDA::copy_to_device(CUDANet::Tensor &tensor, void *data, size_t size) {
|
void CUDA::copy_to_device(CUDANet::Tensor &tensor, void *data, size_t size) {
|
||||||
CUDA_CHECK(cudaMemcpy(tensor.data<float>(), data, size, cudaMemcpyHostToDevice));
|
switch (tensor.get_dtype()) {
|
||||||
|
case DType::FLOAT32:
|
||||||
|
copy_to_device_impl<float>(tensor, data, size);
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("Unsupported dtype");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template void CUDA::copy_to_device_impl<float>(CUDANet::Tensor &tensor, void *data, size_t size);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void CUDA::copy_to_device_impl(CUDANet::Tensor &tensor, void *data, size_t size) {
|
||||||
|
CUDA_CHECK(cudaMemcpy(static_cast<T*>(tensor.device_ptr()), data, size, cudaMemcpyHostToDevice));
|
||||||
}
|
}
|
||||||
|
|
||||||
void CUDA::sum(const CUDANet::Tensor &input, CUDANet::Tensor &sum) {
|
void CUDA::sum(const CUDANet::Tensor &input, CUDANet::Tensor &sum) {
|
||||||
|
switch (input.get_dtype()) {
|
||||||
|
case DType::FLOAT32:
|
||||||
|
sum_impl<float>(input, sum);
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("Unsupported dtype");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template void CUDA::sum_impl<float>(const CUDANet::Tensor &input, CUDANet::Tensor &sum);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void CUDA::sum_impl(const CUDANet::Tensor &input, CUDANet::Tensor &sum) {
|
||||||
auto length = input.numel();
|
auto length = input.numel();
|
||||||
const int gridSize = ( + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
const int gridSize = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||||
|
|
||||||
CUDANet::Kernels::sum_reduce<<<gridSize, BLOCK_SIZE>>>(
|
CUDANet::Kernels::sum_reduce<<<gridSize, BLOCK_SIZE>>>(
|
||||||
input.data<float>(), sum.data<float>(), length
|
static_cast<T*>(input.device_ptr()), static_cast<T*>(sum.device_ptr()), length
|
||||||
);
|
);
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
int remaining = gridSize;
|
int remaining = gridSize;
|
||||||
while (remaining > 1) {
|
while (remaining > 1) {
|
||||||
int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||||
CUDANet::Kernels::sum_reduce<<<blocks_needed, BLOCK_SIZE>>>(sum.data<float>(), sum.data<float>(), remaining);
|
CUDANet::Kernels::sum_reduce<<<blocks_needed, BLOCK_SIZE>>>(static_cast<T*>(sum.device_ptr()), static_cast<T*>(sum.device_ptr()), remaining);
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
remaining = blocks_needed;
|
remaining = blocks_needed;
|
||||||
@@ -54,17 +113,32 @@ void CUDA::sum(const CUDANet::Tensor &input, CUDANet::Tensor &sum) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void CUDA::max(const CUDANet::Tensor &input, CUDANet::Tensor &max) {
|
void CUDA::max(const CUDANet::Tensor &input, CUDANet::Tensor &max) {
|
||||||
|
switch (input.get_dtype()) {
|
||||||
|
case DType::FLOAT32:
|
||||||
|
max_impl<float>(input, max);
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("Unsupported dtype");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template void CUDA::max_impl<float>(const CUDANet::Tensor &input, CUDANet::Tensor &max);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void CUDA::max_impl(const CUDANet::Tensor &input, CUDANet::Tensor &max) {
|
||||||
auto length = input.numel();
|
auto length = input.numel();
|
||||||
const int grid_size = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
const int grid_size = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||||
|
|
||||||
Kernels::max_reduce<<<grid_size, BLOCK_SIZE>>>(input.data<float>(), max.data<float>(), length);
|
Kernels::max_reduce<<<grid_size, BLOCK_SIZE>>>(static_cast<T*>(input.device_ptr()), static_cast<T*>(max.device_ptr()), length);
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
int remaining = grid_size;
|
int remaining = grid_size;
|
||||||
|
|
||||||
while (remaining > 1) {
|
while (remaining > 1) {
|
||||||
int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||||
CUDANet::Kernels::max_reduce<<<blocks_needed, BLOCK_SIZE>>>(max.data<float>(), max.data<float>(), remaining);
|
CUDANet::Kernels::max_reduce<<<blocks_needed, BLOCK_SIZE>>>(static_cast<T*>(max.device_ptr()), static_cast<T*>(max.device_ptr()), remaining);
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
remaining = blocks_needed;
|
remaining = blocks_needed;
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
#include "activation.hpp"
|
|
||||||
|
|
||||||
#include <format>
|
#include <format>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "layers/activation.hpp"
|
||||||
#include "tensor.hpp"
|
#include "tensor.hpp"
|
||||||
|
|
||||||
|
|
||||||
using namespace CUDANet::Layers;
|
using namespace CUDANet::Layers;
|
||||||
|
|
||||||
Activation::Activation(
|
Activation::Activation(
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
#include "add.hpp"
|
#include "layers/add.hpp"
|
||||||
|
|
||||||
using namespace CUDANet::Layers;
|
using namespace CUDANet::Layers;
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
|
#include <format>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
|
|
||||||
#include "avg_pool.hpp"
|
#include "layers/avg_pool.hpp"
|
||||||
#include <format>
|
|
||||||
|
|
||||||
using namespace CUDANet::Layers;
|
using namespace CUDANet::Layers;
|
||||||
|
|
||||||
@@ -84,11 +84,11 @@ CUDANet::Shape AvgPool2d::output_shape() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
size_t AvgPool2d::input_size() {
|
size_t AvgPool2d::input_size() {
|
||||||
return sizeof(float) * in_shape[0] * in_shape[1] * in_shape[2];
|
return dtype_size(dtype) * in_shape[0] * in_shape[1] * in_shape[2];
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t AvgPool2d::output_size() {
|
size_t AvgPool2d::output_size() {
|
||||||
return sizeof(float) * out_shape[0] * out_shape[1] * out_shape[2];
|
return dtype_size(dtype) * out_shape[0] * out_shape[1] * out_shape[2];
|
||||||
}
|
}
|
||||||
|
|
||||||
void AvgPool2d::set_weights(void* input) {}
|
void AvgPool2d::set_weights(void* input) {}
|
||||||
|
|||||||
@@ -1,9 +1,7 @@
|
|||||||
#include "batch_norm.hpp"
|
|
||||||
|
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "activation.hpp"
|
#include "layers/batch_norm.hpp"
|
||||||
#include "layer.hpp"
|
#include "layer.hpp"
|
||||||
|
|
||||||
using namespace CUDANet::Layers;
|
using namespace CUDANet::Layers;
|
||||||
@@ -30,7 +28,7 @@ BatchNorm2d::BatchNorm2d(
|
|||||||
this->dtype = dtype;
|
this->dtype = dtype;
|
||||||
|
|
||||||
epsilon = CUDANet::Tensor({1}, dtype, backend);
|
epsilon = CUDANet::Tensor({1}, dtype, backend);
|
||||||
epsilon.set_data<float>(&eps);
|
epsilon.set_data(&eps);
|
||||||
|
|
||||||
running_mean = CUDANet::Tensor({in_shape[2]}, dtype, backend);
|
running_mean = CUDANet::Tensor({in_shape[2]}, dtype, backend);
|
||||||
running_mean.zero();
|
running_mean.zero();
|
||||||
@@ -73,15 +71,15 @@ CUDANet::Shape BatchNorm2d::output_shape() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
size_t BatchNorm2d::input_size() {
|
size_t BatchNorm2d::input_size() {
|
||||||
return sizeof(float) * in_shape[0] * in_shape[1] * in_shape[2];
|
return dtype_size(dtype) * in_shape[0] * in_shape[1] * in_shape[2];
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t BatchNorm2d::output_size() {
|
size_t BatchNorm2d::output_size() {
|
||||||
return sizeof(float) * in_shape[0] * in_shape[1] * in_shape[2];
|
return dtype_size(dtype) * in_shape[0] * in_shape[1] * in_shape[2];
|
||||||
}
|
}
|
||||||
|
|
||||||
void BatchNorm2d::set_weights(void* input) {
|
void BatchNorm2d::set_weights(void* input) {
|
||||||
weights.set_data<float>(static_cast<float*>(input));
|
weights.set_data(input);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t BatchNorm2d::get_weights_size() {
|
size_t BatchNorm2d::get_weights_size() {
|
||||||
@@ -89,7 +87,7 @@ size_t BatchNorm2d::get_weights_size() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void BatchNorm2d::set_biases(void* input) {
|
void BatchNorm2d::set_biases(void* input) {
|
||||||
biases.set_data<float>(static_cast<float*>(input));
|
biases.set_data(input);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t BatchNorm2d::get_biases_size() {
|
size_t BatchNorm2d::get_biases_size() {
|
||||||
@@ -97,7 +95,7 @@ size_t BatchNorm2d::get_biases_size() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void BatchNorm2d::set_running_mean(void* input) {
|
void BatchNorm2d::set_running_mean(void* input) {
|
||||||
running_mean.set_data<float>(static_cast<float*>(input));
|
running_mean.set_data(input);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t BatchNorm2d::get_running_mean_size() {
|
size_t BatchNorm2d::get_running_mean_size() {
|
||||||
@@ -105,7 +103,7 @@ size_t BatchNorm2d::get_running_mean_size() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void BatchNorm2d::set_running_var(void* input) {
|
void BatchNorm2d::set_running_var(void* input) {
|
||||||
running_var.set_data<float>(static_cast<float*>(input));
|
running_var.set_data(input);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t BatchNorm2d::get_running_var_size() {
|
size_t BatchNorm2d::get_running_var_size() {
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
#include "concat.hpp"
|
#include "layers/concat.hpp"
|
||||||
|
|
||||||
using namespace CUDANet::Layers;
|
using namespace CUDANet::Layers;
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
#include "conv2d.hpp"
|
|
||||||
|
|
||||||
#include <format>
|
#include <format>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
|
|
||||||
|
#include "layers/conv2d.hpp"
|
||||||
#include "layer.hpp"
|
#include "layer.hpp"
|
||||||
#include "tensor.hpp"
|
#include "tensor.hpp"
|
||||||
|
|
||||||
@@ -97,15 +96,15 @@ CUDANet::Shape Conv2d::output_shape() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
size_t Conv2d::input_size() {
|
size_t Conv2d::input_size() {
|
||||||
return sizeof(float) * in_shape[0] * in_shape[1] * in_shape[2];
|
return dtype_size(dtype) * in_shape[0] * in_shape[1] * in_shape[2];
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t Conv2d::output_size() {
|
size_t Conv2d::output_size() {
|
||||||
return sizeof(float) * out_shape[0] * out_shape[1] * out_shape[2];
|
return dtype_size(dtype) * out_shape[0] * out_shape[1] * out_shape[2];
|
||||||
}
|
}
|
||||||
|
|
||||||
void Conv2d::set_weights(void* input) {
|
void Conv2d::set_weights(void* input) {
|
||||||
weights.set_data<float>(static_cast<float*>(input));
|
weights.set_data(input);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t Conv2d::get_weights_size() {
|
size_t Conv2d::get_weights_size() {
|
||||||
@@ -113,7 +112,7 @@ size_t Conv2d::get_weights_size() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Conv2d::set_biases(void* input) {
|
void Conv2d::set_biases(void* input) {
|
||||||
biases.set_data<float>(static_cast<float*>(input));
|
biases.set_data(input);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t Conv2d::get_biases_size() {
|
size_t Conv2d::get_biases_size() {
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
#include "dense.hpp"
|
|
||||||
|
|
||||||
#include <format>
|
#include <format>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
|
|
||||||
|
#include "layers/dense.hpp"
|
||||||
|
|
||||||
using namespace CUDANet::Layers;
|
using namespace CUDANet::Layers;
|
||||||
|
|
||||||
Dense::Dense(CUDANet::Shape in_shape, CUDANet::Shape out_shape, CUDANet::Backend* backend)
|
Dense::Dense(CUDANet::Shape in_shape, CUDANet::Shape out_shape, CUDANet::Backend* backend)
|
||||||
@@ -56,8 +56,9 @@ size_t Dense::output_size() {
|
|||||||
return out_shape[0];
|
return out_shape[0];
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// TODO: Use dtype
|
||||||
void Dense::set_weights(void* input) {
|
void Dense::set_weights(void* input) {
|
||||||
weights.set_data<float>(static_cast<float*>(input));
|
weights.set_data(input);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t Dense::get_weights_size() {
|
size_t Dense::get_weights_size() {
|
||||||
@@ -65,7 +66,7 @@ size_t Dense::get_weights_size() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Dense::set_biases(void* input) {
|
void Dense::set_biases(void* input) {
|
||||||
biases.set_data<float>(static_cast<float*>(input));
|
biases.set_data(input);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t Dense::get_biases_size() {
|
size_t Dense::get_biases_size() {
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
#include "max_pool.hpp"
|
|
||||||
|
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
|
|
||||||
|
#include "layers/max_pool.hpp"
|
||||||
|
|
||||||
using namespace CUDANet::Layers;
|
using namespace CUDANet::Layers;
|
||||||
|
|
||||||
MaxPool2d::MaxPool2d(
|
MaxPool2d::MaxPool2d(
|
||||||
@@ -78,11 +78,11 @@ CUDANet::Shape MaxPool2d::output_shape() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
size_t MaxPool2d::input_size() {
|
size_t MaxPool2d::input_size() {
|
||||||
return sizeof(float) * in_shape[0] * in_shape[1] * in_shape[2];
|
return dtype_size(dtype) * in_shape[0] * in_shape[1] * in_shape[2];
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t MaxPool2d::output_size() {
|
size_t MaxPool2d::output_size() {
|
||||||
return sizeof(float) * out_shape[0] * out_shape[1] * out_shape[2];
|
return dtype_size(dtype) * out_shape[0] * out_shape[1] * out_shape[2];
|
||||||
}
|
}
|
||||||
|
|
||||||
void MaxPool2d::set_weights(void* input) {}
|
void MaxPool2d::set_weights(void* input) {}
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
#include "model.hpp"
|
|
||||||
|
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
@@ -8,7 +6,9 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "layer.hpp"
|
#include "layer.hpp"
|
||||||
#include "batch_norm.hpp"
|
#include "layers/batch_norm.hpp"
|
||||||
|
|
||||||
|
#include "model.hpp"
|
||||||
|
|
||||||
using namespace CUDANet;
|
using namespace CUDANet;
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
#include "module.hpp"
|
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
|
#include "module.hpp"
|
||||||
|
|
||||||
using namespace CUDANet;
|
using namespace CUDANet;
|
||||||
|
|
||||||
CUDANet::Shape Module::input_shape() {
|
CUDANet::Shape Module::input_shape() {
|
||||||
@@ -12,22 +12,6 @@ CUDANet::Shape Module::output_shape() {
|
|||||||
return out_shape;
|
return out_shape;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t Module::input_size() {
|
|
||||||
size_t count = 1;
|
|
||||||
for (const auto& dim : in_shape) {
|
|
||||||
count *= dim;
|
|
||||||
}
|
|
||||||
return sizeof(float) * count;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t Module::output_size() {
|
|
||||||
size_t count = 1;
|
|
||||||
for (const auto& dim : out_shape) {
|
|
||||||
count *= dim;
|
|
||||||
}
|
|
||||||
return sizeof(float) * count;
|
|
||||||
}
|
|
||||||
|
|
||||||
void Module::register_layer(const std::string& name, Layer* layer) {
|
void Module::register_layer(const std::string& name, Layer* layer) {
|
||||||
layers.push_back({name, layer});
|
layers.push_back({name, layer});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
#include "tensor.hpp"
|
|
||||||
|
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
|
|
||||||
|
#include "tensor.hpp"
|
||||||
|
|
||||||
using namespace CUDANet;
|
using namespace CUDANet;
|
||||||
|
|
||||||
Tensor::Tensor(Shape shape, CUDANet::Backend* backend)
|
Tensor::Tensor(Shape shape, CUDANet::Backend* backend)
|
||||||
@@ -80,6 +80,10 @@ Tensor::~Tensor() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DType Tensor::get_dtype() const {
|
||||||
|
return dtype;
|
||||||
|
}
|
||||||
|
|
||||||
size_t Tensor::numel() const {
|
size_t Tensor::numel() const {
|
||||||
return total_elms;
|
return total_elms;
|
||||||
}
|
}
|
||||||
@@ -88,6 +92,22 @@ size_t Tensor::size() const {
|
|||||||
return total_size;
|
return total_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void* Tensor::device_ptr() const {
|
||||||
|
return d_ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
void* Tensor::device_ptr() {
|
||||||
|
return d_ptr;
|
||||||
|
}
|
||||||
|
|
||||||
void Tensor::zero() {
|
void Tensor::zero() {
|
||||||
backend->zero(*this);
|
backend->zero(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Tensor::fill(int value) {
|
||||||
|
backend->fill(*this, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Tensor::set_data(void *data) {
|
||||||
|
backend->copy_to_device(*this, data, total_size);
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user