mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-06 01:34:22 +00:00
Fix weight/bias parsing and improve error logging
This commit is contained in:
@@ -89,11 +89,11 @@ void Model::loadWeights(const std::string& path) {
|
||||
|
||||
// Parse tensor name into name and type
|
||||
std::string nameStr = line.substr(0, commaPos);
|
||||
size_t dotPos = nameStr.find('.');
|
||||
size_t dotPos = nameStr.find_last_of('.');
|
||||
if (dotPos == std::string::npos)
|
||||
continue;
|
||||
std::string name = nameStr.substr(0, dotPos);
|
||||
TensorType type = nameStr.substr(dotPos + 1) == "w" ? TensorType::WEIGHT : TensorType::BIAS;
|
||||
TensorType type = nameStr.substr(dotPos + 1) == "weight" ? TensorType::WEIGHT : TensorType::BIAS;
|
||||
|
||||
line = line.substr(commaPos + 1);
|
||||
|
||||
@@ -118,15 +118,31 @@ void Model::loadWeights(const std::string& path) {
|
||||
Layers::WeightedLayer* wLayer = dynamic_cast<Layers::WeightedLayer*>(layerMap[tensorInfo.name]);
|
||||
|
||||
if (wLayer == nullptr) {
|
||||
std::cerr << "Layer: " << tensorInfo.name << "does not have weights, skipping" << std::endl;
|
||||
std::cerr << "Layer: " << tensorInfo.name << " does not have weights" << std::endl;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (tensorInfo.type == TensorType::WEIGHT) {
|
||||
|
||||
if (wLayer->getWeights().size() != values.size()) {
|
||||
std::cerr << "Layer: " << tensorInfo.name << " has incorrect number of weights, expected "
|
||||
<< wLayer->getWeights().size() << " but got " << values.size() << ", skipping" << std::endl;
|
||||
continue;
|
||||
}
|
||||
|
||||
wLayer->setWeights(values.data());
|
||||
} else if (tensorInfo.type == TensorType::BIAS) {
|
||||
|
||||
if (wLayer->getBiases().size() != values.size()) {
|
||||
std::cerr << "Layer: " << tensorInfo.name << " has incorrect number of biases, expected "
|
||||
<< wLayer->getBiases().size() << " but got " << values.size() << ", skipping" << std::endl;
|
||||
continue;
|
||||
}
|
||||
|
||||
wLayer->setBiases(values.data());
|
||||
}
|
||||
} else {
|
||||
std::cerr << "Layer: " << tensorInfo.name << " does not exist, skipping" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user