Fix weigh bias parsing and better error logging

This commit is contained in:
2024-04-20 18:36:53 +02:00
parent ecf7416f8e
commit d08567a563
6 changed files with 93 additions and 11 deletions

View File

@@ -89,11 +89,11 @@ void Model::loadWeights(const std::string& path) {
// Parse tensor name into name and type
std::string nameStr = line.substr(0, commaPos);
size_t dotPos = nameStr.find('.');
size_t dotPos = nameStr.find_last_of('.');
if (dotPos == std::string::npos)
continue;
std::string name = nameStr.substr(0, dotPos);
TensorType type = nameStr.substr(dotPos + 1) == "w" ? TensorType::WEIGHT : TensorType::BIAS;
TensorType type = nameStr.substr(dotPos + 1) == "weight" ? TensorType::WEIGHT : TensorType::BIAS;
line = line.substr(commaPos + 1);
@@ -118,15 +118,31 @@ void Model::loadWeights(const std::string& path) {
Layers::WeightedLayer* wLayer = dynamic_cast<Layers::WeightedLayer*>(layerMap[tensorInfo.name]);
if (wLayer == nullptr) {
std::cerr << "Layer: " << tensorInfo.name << "does not have weights, skipping" << std::endl;
std::cerr << "Layer: " << tensorInfo.name << " does not have weights" << std::endl;
continue;
}
if (tensorInfo.type == TensorType::WEIGHT) {
if (wLayer->getWeights().size() != values.size()) {
std::cerr << "Layer: " << tensorInfo.name << " has incorrect number of weights, expected "
<< wLayer->getWeights().size() << " but got " << values.size() << ", skipping" << std::endl;
continue;
}
wLayer->setWeights(values.data());
} else if (tensorInfo.type == TensorType::BIAS) {
if (wLayer->getBiases().size() != values.size()) {
std::cerr << "Layer: " << tensorInfo.name << " has incorrect number of biases, expected "
<< wLayer->getBiases().size() << " but got " << values.size() << ", skipping" << std::endl;
continue;
}
wLayer->setBiases(values.data());
}
} else {
std::cerr << "Layer: " << tensorInfo.name << " does not exist, skipping" << std::endl;
}
}