mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-05 17:34:21 +00:00
Load running mean and var from weight file
This commit is contained in:
@@ -57,6 +57,12 @@ class BatchNorm2d : public WeightedLayer, public TwoDLayer {
|
|||||||
*/
|
*/
|
||||||
void setRunningMean(const float* running_mean_input);
|
void setRunningMean(const float* running_mean_input);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Get the Running Mean
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
std::vector<float> getRunningMean();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Set the Running Var
|
* @brief Set the Running Var
|
||||||
*
|
*
|
||||||
@@ -64,6 +70,12 @@ class BatchNorm2d : public WeightedLayer, public TwoDLayer {
|
|||||||
*/
|
*/
|
||||||
void setRunningVar(const float* running_mean_input);
|
void setRunningVar(const float* running_mean_input);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Get the Running Var
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
std::vector<float> getRunningVar();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Get output size
|
* @brief Get output size
|
||||||
*
|
*
|
||||||
|
|||||||
@@ -15,6 +15,8 @@ namespace CUDANet {
|
|||||||
enum TensorType {
|
enum TensorType {
|
||||||
WEIGHT,
|
WEIGHT,
|
||||||
BIAS,
|
BIAS,
|
||||||
|
RUNNING_MEAN,
|
||||||
|
RUNNING_VAR
|
||||||
};
|
};
|
||||||
|
|
||||||
struct TensorInfo {
|
struct TensorInfo {
|
||||||
|
|||||||
@@ -121,11 +121,19 @@ void BatchNorm2d::setRunningMean(const float* running_mean_input) {
|
|||||||
toCuda();
|
toCuda();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<float> BatchNorm2d::getRunningMean() {
|
||||||
|
return running_mean;
|
||||||
|
}
|
||||||
|
|
||||||
void BatchNorm2d::setRunningVar(const float* running_var_input) {
|
void BatchNorm2d::setRunningVar(const float* running_var_input) {
|
||||||
std::copy(running_var_input, running_var_input + inputChannels, running_var.begin());
|
std::copy(running_var_input, running_var_input + inputChannels, running_var.begin());
|
||||||
toCuda();
|
toCuda();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<float> BatchNorm2d::getRunningVar() {
|
||||||
|
return running_var;
|
||||||
|
}
|
||||||
|
|
||||||
void BatchNorm2d::toCuda() {
|
void BatchNorm2d::toCuda() {
|
||||||
CUDA_CHECK(cudaMemcpy(
|
CUDA_CHECK(cudaMemcpy(
|
||||||
d_weights, weights.data(), sizeof(float) * inputChannels,
|
d_weights, weights.data(), sizeof(float) * inputChannels,
|
||||||
|
|||||||
@@ -9,6 +9,7 @@
|
|||||||
|
|
||||||
#include "input.cuh"
|
#include "input.cuh"
|
||||||
#include "layer.cuh"
|
#include "layer.cuh"
|
||||||
|
#include "batch_norm.cuh"
|
||||||
|
|
||||||
using namespace CUDANet;
|
using namespace CUDANet;
|
||||||
|
|
||||||
@@ -91,6 +92,14 @@ void Model::loadWeights(const std::string& path) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto getTensorType = [](const std::string& typeStr) {
|
||||||
|
if (typeStr == "weight") return TensorType::WEIGHT;
|
||||||
|
if (typeStr == "bias") return TensorType::BIAS;
|
||||||
|
if (typeStr == "running_mean") return TensorType::RUNNING_MEAN;
|
||||||
|
if (typeStr == "running_var") return TensorType::RUNNING_VAR;
|
||||||
|
throw std::runtime_error("Unknown tensor type: " + typeStr);
|
||||||
|
};
|
||||||
|
|
||||||
u_int64_t headerSize;
|
u_int64_t headerSize;
|
||||||
file.read(reinterpret_cast<char*>(&headerSize), sizeof(headerSize));
|
file.read(reinterpret_cast<char*>(&headerSize), sizeof(headerSize));
|
||||||
|
|
||||||
@@ -115,9 +124,8 @@ void Model::loadWeights(const std::string& path) {
|
|||||||
size_t dotPos = nameStr.find_last_of('.');
|
size_t dotPos = nameStr.find_last_of('.');
|
||||||
if (dotPos == std::string::npos) continue;
|
if (dotPos == std::string::npos) continue;
|
||||||
std::string name = nameStr.substr(0, dotPos);
|
std::string name = nameStr.substr(0, dotPos);
|
||||||
TensorType type = nameStr.substr(dotPos + 1) == "weight"
|
|
||||||
? TensorType::WEIGHT
|
TensorType type = getTensorType(nameStr.substr(dotPos + 1));
|
||||||
: TensorType::BIAS;
|
|
||||||
|
|
||||||
line = line.substr(commaPos + 1);
|
line = line.substr(commaPos + 1);
|
||||||
|
|
||||||
@@ -173,6 +181,29 @@ void Model::loadWeights(const std::string& path) {
|
|||||||
|
|
||||||
wLayer->setBiases(values.data());
|
wLayer->setBiases(values.data());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Layers::BatchNorm2d* bnLayer = dynamic_cast<Layers::BatchNorm2d*>(wLayer);
|
||||||
|
if (bnLayer == nullptr) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tensorInfo.type == TensorType::RUNNING_MEAN) {
|
||||||
|
if (bnLayer->getRunningMean().size() != values.size()) {
|
||||||
|
std::cerr << "Layer: " << tensorInfo.name << " has incorrect number of running mean values, expected "
|
||||||
|
<< bnLayer->getRunningMean().size() << " but got " << values.size() << ", skipping" << std::endl;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
bnLayer->setRunningMean(values.data());
|
||||||
|
} else if (tensorInfo.type == TensorType::RUNNING_VAR) {
|
||||||
|
if (bnLayer->getRunningVar().size() != values.size()) {
|
||||||
|
std::cerr << "Layer: " << tensorInfo.name << " has incorrect number of running var values, expected "
|
||||||
|
<< bnLayer->getRunningVar().size() << " but got " << values.size() << ", skipping" << std::endl;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
bnLayer->setRunningVar(values.data());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
std::cerr << "Layer: " << tensorInfo.name
|
std::cerr << "Layer: " << tensorInfo.name
|
||||||
<< " does not exist, skipping" << std::endl;
|
<< " does not exist, skipping" << std::endl;
|
||||||
|
|||||||
Reference in New Issue
Block a user