Move softmax partial kernels to matmul

This commit is contained in:
2024-04-11 22:01:47 +02:00
parent bf7c961b9e
commit 710a33bdde
6 changed files with 274 additions and 212 deletions

View File

@@ -35,17 +35,6 @@ __global__ void vec_vec_add(
const unsigned int w
);
/**
* @brief Max reduction kernel
*
* @param d_vector Device pointer to vector
* @param d_output Device pointer to output vector
*/
__global__ void max_reduce(
const float* __restrict__ d_vector,
float* __restrict__ d_output
);
/**
* @brief Add scalar to each element of the vector
*
@@ -56,15 +45,68 @@ __global__ void max_reduce(
* @return __global__
*/
__global__ void vec_scalar_sub(
const float* __restrict__ d_vector,
const float* __restrict__ d_src,
float* __restrict__ d_out,
const float* __restrict__ d_scalar,
float* __restrict__ d_output,
const unsigned int w
const unsigned int len
);
/**
* @brief Softmax activation function kernel
*
* @param src Pointer to the source array
* @param dst Pointer to the destination array
* @param len Length of the arrays
*/
__global__ void vec_scalar_div(
const float* __restrict__ d_src,
float* __restrict__ d_out,
const float* __restrict__ d_scalar,
const unsigned int len
);
/**
* @brief Softmax activation exponentiation kernel
*
* @param src Pointer to the source array
* @param dst Pointer to the destination array
* @param len Length of the arrays
*/
__global__ void vec_exp(
const float* __restrict__ src,
float* __restrict__ dst,
const unsigned int len
);
/**
* @brief Max reduction kernel
*
* @param d_vector Device pointer to vector
* @param d_output Device pointer to output vector
*/
__global__ void max_reduce(
const float* __restrict__ d_vector,
float* __restrict__ d_output,
const unsigned int len
);
/**
* @brief
*
* @param d_vector Device pointer to vector
* @param d_output Device pointer to output vector
* @param len Length of the vector
*/
__global__ void sum_reduce(
const float* __restrict__ d_vector,
float* __restrict__ d_output,
const unsigned int len
);
__global__ void clear(
float* __restrict__ d_vector,
const unsigned int w
const unsigned int len
);
} // namespace CUDANet::Kernels