Add version to model bin format

This commit is contained in:
2024-04-23 19:53:48 +02:00
parent 1592f06121
commit 69111f6cb1
4 changed files with 16 additions and 21 deletions

View File

@@ -66,7 +66,15 @@ void Model::loadWeights(const std::string& path) {
return;
}
int64_t headerSize;
u_short version;
file.read(reinterpret_cast<char*>(&version), sizeof(version));
if (version != 1) {
std::cerr << "Unsupported model version: " << version << std::endl;
return;
}
u_int64_t headerSize;
file.read(reinterpret_cast<char*>(&headerSize), sizeof(headerSize));
std::string header(headerSize, '\0');
@@ -110,7 +118,7 @@ void Model::loadWeights(const std::string& path) {
for (const auto& tensorInfo : tensorInfos) {
std::vector<float> values(tensorInfo.size);
file.seekg(sizeof(int64_t) + header.size() + tensorInfo.offset);
file.seekg(sizeof(version) + sizeof(headerSize) + header.size() + tensorInfo.offset);
file.read(reinterpret_cast<char*>(values.data()), tensorInfo.size * sizeof(float));
if (layerMap.find(tensorInfo.name) != layerMap.end()) {