From bc9bff10cdc8436426e5a55324e3db7b7f104051 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Sun, 25 Aug 2024 19:33:33 +0200 Subject: [PATCH] Load running mean and var from weight file --- include/layers/batch_norm.cuh | 12 ++++++++++++ include/model/model.hpp | 2 ++ src/layers/batch_norm.cu | 8 ++++++++ src/model/model.cpp | 37 ++++++++++++++++++++++++++++++++--- 4 files changed, 56 insertions(+), 3 deletions(-) diff --git a/include/layers/batch_norm.cuh b/include/layers/batch_norm.cuh index 1796c54..c940d93 100644 --- a/include/layers/batch_norm.cuh +++ b/include/layers/batch_norm.cuh @@ -57,6 +57,12 @@ class BatchNorm2d : public WeightedLayer, public TwoDLayer { */ void setRunningMean(const float* running_mean_input); + /** + * @brief Get the Running Mean + * + */ + std::vector<float> getRunningMean(); + /** * @brief Set the Running Var * * @@ -64,6 +70,12 @@ class BatchNorm2d : public WeightedLayer, public TwoDLayer { */ void setRunningVar(const float* running_mean_input); + /** + * @brief Get the Running Var + * + */ + std::vector<float> getRunningVar(); + /** * @brief Get output size * diff --git a/include/model/model.hpp b/include/model/model.hpp index 78a6f10..5a2d777 100644 --- a/include/model/model.hpp +++ b/include/model/model.hpp @@ -15,6 +15,8 @@ namespace CUDANet { enum TensorType { WEIGHT, BIAS, + RUNNING_MEAN, + RUNNING_VAR }; struct TensorInfo { diff --git a/src/layers/batch_norm.cu b/src/layers/batch_norm.cu index 9b82606..b086525 100644 --- a/src/layers/batch_norm.cu +++ b/src/layers/batch_norm.cu @@ -121,11 +121,19 @@ void BatchNorm2d::setRunningMean(const float* running_mean_input) { toCuda(); } +std::vector<float> BatchNorm2d::getRunningMean() { + return running_mean; +} + void BatchNorm2d::setRunningVar(const float* running_var_input) { std::copy(running_var_input, running_var_input + inputChannels, running_var.begin()); toCuda(); } +std::vector<float> BatchNorm2d::getRunningVar() { + return running_var; +} + void BatchNorm2d::toCuda() { CUDA_CHECK(cudaMemcpy( d_weights, 
weights.data(), sizeof(float) * inputChannels, diff --git a/src/model/model.cpp b/src/model/model.cpp index 89d65ac..ba5b88c 100644 --- a/src/model/model.cpp +++ b/src/model/model.cpp @@ -9,6 +9,7 @@ #include "input.cuh" #include "layer.cuh" +#include "batch_norm.cuh" using namespace CUDANet; @@ -91,6 +92,14 @@ void Model::loadWeights(const std::string& path) { return; } + auto getTensorType = [](const std::string& typeStr) { + if (typeStr == "weight") return TensorType::WEIGHT; + if (typeStr == "bias") return TensorType::BIAS; + if (typeStr == "running_mean") return TensorType::RUNNING_MEAN; + if (typeStr == "running_var") return TensorType::RUNNING_VAR; + throw std::runtime_error("Unknown tensor type: " + typeStr); + }; + u_int64_t headerSize; file.read(reinterpret_cast<char*>(&headerSize), sizeof(headerSize)); @@ -115,9 +124,8 @@ void Model::loadWeights(const std::string& path) { size_t dotPos = nameStr.find_last_of('.'); if (dotPos == std::string::npos) continue; std::string name = nameStr.substr(0, dotPos); - TensorType type = nameStr.substr(dotPos + 1) == "weight" - ? 
TensorType::WEIGHT - : TensorType::BIAS; + + TensorType type = getTensorType(nameStr.substr(dotPos + 1)); line = line.substr(commaPos + 1); @@ -173,6 +181,29 @@ void Model::loadWeights(const std::string& path) { wLayer->setBiases(values.data()); } + + Layers::BatchNorm2d* bnLayer = dynamic_cast<Layers::BatchNorm2d*>(wLayer); + if (bnLayer == nullptr) { + continue; + } + + if (tensorInfo.type == TensorType::RUNNING_MEAN) { + if (bnLayer->getRunningMean().size() != values.size()) { + std::cerr << "Layer: " << tensorInfo.name << " has incorrect number of running mean values, expected " + << bnLayer->getRunningMean().size() << " but got " << values.size() << ", skipping" << std::endl; + continue; + } + bnLayer->setRunningMean(values.data()); + } else if (tensorInfo.type == TensorType::RUNNING_VAR) { + if (bnLayer->getRunningVar().size() != values.size()) { + std::cerr << "Layer: " << tensorInfo.name << " has incorrect number of running var values, expected " + << bnLayer->getRunningVar().size() << " but got " << values.size() << ", skipping" << std::endl; + continue; + } + bnLayer->setRunningVar(values.data()); + } + + } else { std::cerr << "Layer: " << tensorInfo.name << " does not exist, skipping" << std::endl;