Load running mean and var from weight file

This commit is contained in:
2024-08-25 19:33:33 +02:00
parent 9704d0d53e
commit bc9bff10cd
4 changed files with 56 additions and 3 deletions

View File

@@ -57,6 +57,12 @@ class BatchNorm2d : public WeightedLayer, public TwoDLayer {
*/ */
void setRunningMean(const float* running_mean_input); void setRunningMean(const float* running_mean_input);
/**
* @brief Get the Running Mean
*
*/
std::vector<float> getRunningMean();
/** /**
* @brief Set the Running Var * @brief Set the Running Var
* *
@@ -64,6 +70,12 @@ class BatchNorm2d : public WeightedLayer, public TwoDLayer {
*/ */
void setRunningVar(const float* running_var_input); void setRunningVar(const float* running_var_input);
/**
* @brief Get the Running Var
*
*/
std::vector<float> getRunningVar();
/** /**
* @brief Get output size * @brief Get output size
* *

View File

@@ -15,6 +15,8 @@ namespace CUDANet {
enum TensorType { enum TensorType {
WEIGHT, WEIGHT,
BIAS, BIAS,
RUNNING_MEAN,
RUNNING_VAR
}; };
struct TensorInfo { struct TensorInfo {

View File

@@ -121,11 +121,19 @@ void BatchNorm2d::setRunningMean(const float* running_mean_input) {
toCuda(); toCuda();
} }
/**
 * @brief Get the Running Mean
 *
 * Returns a copy of the host-side running_mean vector. The setter
 * (setRunningMean) updates this host buffer and then calls toCuda(),
 * so the host vector read here is the authoritative value; no
 * device-to-host read-back is performed.
 *
 * @return std::vector<float> Copy of the per-channel running mean.
 */
std::vector<float> BatchNorm2d::getRunningMean() {
return running_mean;
}
void BatchNorm2d::setRunningVar(const float* running_var_input) { void BatchNorm2d::setRunningVar(const float* running_var_input) {
std::copy(running_var_input, running_var_input + inputChannels, running_var.begin()); std::copy(running_var_input, running_var_input + inputChannels, running_var.begin());
toCuda(); toCuda();
} }
/**
 * @brief Get the Running Var
 *
 * Returns a copy of the host-side running_var vector. The setter
 * (setRunningVar) updates this host buffer and then calls toCuda(),
 * so the host vector read here is the authoritative value; no
 * device-to-host read-back is performed.
 *
 * @return std::vector<float> Copy of the per-channel running variance.
 */
std::vector<float> BatchNorm2d::getRunningVar() {
return running_var;
}
void BatchNorm2d::toCuda() { void BatchNorm2d::toCuda() {
CUDA_CHECK(cudaMemcpy( CUDA_CHECK(cudaMemcpy(
d_weights, weights.data(), sizeof(float) * inputChannels, d_weights, weights.data(), sizeof(float) * inputChannels,

View File

@@ -9,6 +9,7 @@
#include "input.cuh" #include "input.cuh"
#include "layer.cuh" #include "layer.cuh"
#include "batch_norm.cuh"
using namespace CUDANet; using namespace CUDANet;
@@ -91,6 +92,14 @@ void Model::loadWeights(const std::string& path) {
return; return;
} }
auto getTensorType = [](const std::string& typeStr) {
if (typeStr == "weight") return TensorType::WEIGHT;
if (typeStr == "bias") return TensorType::BIAS;
if (typeStr == "running_mean") return TensorType::RUNNING_MEAN;
if (typeStr == "running_var") return TensorType::RUNNING_VAR;
throw std::runtime_error("Unknown tensor type: " + typeStr);
};
u_int64_t headerSize; u_int64_t headerSize;
file.read(reinterpret_cast<char*>(&headerSize), sizeof(headerSize)); file.read(reinterpret_cast<char*>(&headerSize), sizeof(headerSize));
@@ -115,9 +124,8 @@ void Model::loadWeights(const std::string& path) {
size_t dotPos = nameStr.find_last_of('.'); size_t dotPos = nameStr.find_last_of('.');
if (dotPos == std::string::npos) continue; if (dotPos == std::string::npos) continue;
std::string name = nameStr.substr(0, dotPos); std::string name = nameStr.substr(0, dotPos);
TensorType type = nameStr.substr(dotPos + 1) == "weight"
? TensorType::WEIGHT TensorType type = getTensorType(nameStr.substr(dotPos + 1));
: TensorType::BIAS;
line = line.substr(commaPos + 1); line = line.substr(commaPos + 1);
@@ -173,6 +181,29 @@ void Model::loadWeights(const std::string& path) {
wLayer->setBiases(values.data()); wLayer->setBiases(values.data());
} }
Layers::BatchNorm2d* bnLayer = dynamic_cast<Layers::BatchNorm2d*>(wLayer);
if (bnLayer == nullptr) {
continue;
}
if (tensorInfo.type == TensorType::RUNNING_MEAN) {
if (bnLayer->getRunningMean().size() != values.size()) {
std::cerr << "Layer: " << tensorInfo.name << " has incorrect number of running mean values, expected "
<< bnLayer->getRunningMean().size() << " but got " << values.size() << ", skipping" << std::endl;
continue;
}
bnLayer->setRunningMean(values.data());
} else if (tensorInfo.type == TensorType::RUNNING_VAR) {
if (bnLayer->getRunningVar().size() != values.size()) {
std::cerr << "Layer: " << tensorInfo.name << " has incorrect number of running var values, expected "
<< bnLayer->getRunningVar().size() << " but got " << values.size() << ", skipping" << std::endl;
continue;
}
bnLayer->setRunningVar(values.data());
}
} else { } else {
std::cerr << "Layer: " << tensorInfo.name std::cerr << "Layer: " << tensorInfo.name
<< " does not exist, skipping" << std::endl; << " does not exist, skipping" << std::endl;