Improve softmax numerical stability

This commit is contained in:
2024-04-08 23:25:46 +02:00
parent e419a93408
commit b49dddf34a
6 changed files with 119 additions and 4 deletions

View File

@@ -35,6 +35,33 @@ __global__ void vec_vec_add(
const unsigned int w
);
/**
* @brief Max reduction kernel
*
* NOTE(review): no element count is passed to this kernel, so the number of
* elements processed is presumably derived from the launch configuration -
* confirm against the definition and the call sites.
*
* @param d_vector Device pointer to the input vector (read-only)
* @param d_output Device pointer to the output vector; whether it holds a
*                 single maximum or one partial result per block is not
*                 visible from this declaration - TODO confirm
*/
__global__ void max_reduce(
const float* __restrict__ d_vector,
float* __restrict__ d_output
);
/**
* @brief Vector/scalar subtraction kernel
*
* NOTE(review): this comment previously read "Add scalar to each element",
* which contradicts the kernel's name. Presumably the scalar is subtracted
* from each element (d_output[i] = d_vector[i] - *d_scalar) - confirm the
* operand order against the definition.
*
* @param d_vector Pointer to the input vector (read-only); presumably a
*                 device pointer, judging by the d_ prefix
* @param d_scalar Pointer to the scalar operand (read-only); passed by
*                 pointer rather than by value
* @param d_output Pointer to the output vector
* @param w        Presumably the number of elements in the vector - TODO
*                 confirm
*/
__global__ void vec_scalar_sub(
const float* __restrict__ d_vector,
const float* __restrict__ d_scalar,
float* __restrict__ d_output,
const unsigned int w
);
} // namespace CUDANet::Kernels
#endif // CUDANET_MATMUL_H