mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-05 17:34:21 +00:00
Load running mean and var from weight file
This commit is contained in:
@@ -57,6 +57,12 @@ class BatchNorm2d : public WeightedLayer, public TwoDLayer {
|
||||
*/
|
||||
void setRunningMean(const float* running_mean_input);
|
||||
|
||||
/**
|
||||
* @brief Get the Running Mean
|
||||
*
|
||||
*/
|
||||
std::vector<float> getRunningMean();
|
||||
|
||||
/**
|
||||
* @brief Set the Running Var
|
||||
*
|
||||
@@ -64,6 +70,12 @@ class BatchNorm2d : public WeightedLayer, public TwoDLayer {
|
||||
*/
|
||||
void setRunningVar(const float* running_mean_input);
|
||||
|
||||
/**
|
||||
* @brief Get the Running Var
|
||||
*
|
||||
*/
|
||||
std::vector<float> getRunningVar();
|
||||
|
||||
/**
|
||||
* @brief Get output size
|
||||
*
|
||||
|
||||
@@ -15,6 +15,8 @@ namespace CUDANet {
|
||||
enum TensorType {
|
||||
WEIGHT,
|
||||
BIAS,
|
||||
RUNNING_MEAN,
|
||||
RUNNING_VAR
|
||||
};
|
||||
|
||||
struct TensorInfo {
|
||||
|
||||
Reference in New Issue
Block a user