mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-05 17:34:21 +00:00
Compute mean and variance
This commit is contained in:
@@ -135,6 +135,20 @@ __global__ void vec_exp(
|
||||
const unsigned int len
|
||||
);
|
||||
|
||||
/**
|
||||
* @brief Compute the square root of each element of the vector
|
||||
*
|
||||
* @param src Device pointer to source vector
|
||||
* @param dst Device pointer to destination vector
|
||||
* @param len Length of the vector
|
||||
* @return __global__
|
||||
*/
|
||||
__global__ void vec_sqrt(
|
||||
const float* __restrict__ src,
|
||||
float* __restrict__ dst,
|
||||
const unsigned int len
|
||||
);
|
||||
|
||||
/**
|
||||
* @brief Max reduction kernel
|
||||
*
|
||||
|
||||
@@ -74,8 +74,12 @@ class BatchNorm : public WeightedLayer {
|
||||
float* d_output;
|
||||
|
||||
float* d_mean;
|
||||
float* d_mean_sub;
|
||||
float* d_sqrt_var;
|
||||
|
||||
float* d_length;
|
||||
float* d_epsilon;
|
||||
|
||||
float* d_weights;
|
||||
float* d_biases;
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ namespace CUDANet::Utils {
|
||||
* @param d_vec Pointer to the vector on device
|
||||
* @param length Length of the vector
|
||||
*/
|
||||
void print_vec(float *d_vec, const unsigned int length);
|
||||
void print_vec(const float *d_vec, const unsigned int length);
|
||||
|
||||
/**
|
||||
* @brief Utility function that clears a vector
|
||||
@@ -27,7 +27,7 @@ void clear(float *d_vector, const unsigned int len);
|
||||
* @param d_vec Pointer to the vector
|
||||
* @param length Length of the vector
|
||||
*/
|
||||
void sum(float *d_vec, float *d_sum, const unsigned int length);
|
||||
void sum(const float *d_vec, float *d_sum, const unsigned int length);
|
||||
|
||||
|
||||
/**
|
||||
@@ -36,25 +36,16 @@ void sum(float *d_vec, float *d_sum, const unsigned int length);
|
||||
* @param d_vec Pointer to the vector
|
||||
* @param length Length of the vector
|
||||
*/
|
||||
void max(float *d_vec, float *d_max, const unsigned int length);
|
||||
void max(const float *d_vec, float *d_max, const unsigned int length);
|
||||
|
||||
/**
|
||||
* @brief Compute the mean of a vector
|
||||
*
|
||||
* @param d_vec
|
||||
* @param d_mean
|
||||
* @param length
|
||||
*/
|
||||
void mean(float *d_vec, float *d_mean, const unsigned int length);
|
||||
|
||||
/**
|
||||
* @brief Compute the variance of a vector
|
||||
*
|
||||
* @param d_vec
|
||||
* @param d_var
|
||||
* @param length
|
||||
*/
|
||||
void var(float *d_vec, float *d_var, const unsigned int length);
|
||||
// /**
|
||||
// * @brief Compute the variance of a vector
|
||||
// *
|
||||
// * @param d_vec
|
||||
// * @param d_var
|
||||
// * @param length
|
||||
// */
|
||||
// void var(float *d_vec, float *d_var, const unsigned int length);
|
||||
|
||||
} // namespace CUDANet::Utils
|
||||
|
||||
|
||||
Reference in New Issue
Block a user