Load running mean and var from weight file

This commit is contained in:
2024-08-25 19:33:33 +02:00
parent 9704d0d53e
commit bc9bff10cd
4 changed files with 56 additions and 3 deletions

View File

@@ -57,6 +57,12 @@ class BatchNorm2d : public WeightedLayer, public TwoDLayer {
*/
void setRunningMean(const float* running_mean_input);
/**
* @brief Get the Running Mean
*
*/
std::vector<float> getRunningMean();
/**
* @brief Set the Running Var
*
@@ -64,6 +70,12 @@ class BatchNorm2d : public WeightedLayer, public TwoDLayer {
*/
void setRunningVar(const float* running_mean_input);
/**
* @brief Get the Running Var
*
*/
std::vector<float> getRunningVar();
/**
* @brief Get output size
*

View File

@@ -15,6 +15,8 @@ namespace CUDANet {
enum TensorType {
WEIGHT,
BIAS,
RUNNING_MEAN,
RUNNING_VAR
};
struct TensorInfo {