mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-06 01:34:22 +00:00
Add version to model bin format
This commit is contained in:
@@ -66,7 +66,15 @@ void Model::loadWeights(const std::string& path) {
|
||||
return;
|
||||
}
|
||||
|
||||
int64_t headerSize;
|
||||
u_short version;
|
||||
file.read(reinterpret_cast<char*>(&version), sizeof(version));
|
||||
|
||||
if (version != 1) {
|
||||
std::cerr << "Unsupported model version: " << version << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
u_int64_t headerSize;
|
||||
file.read(reinterpret_cast<char*>(&headerSize), sizeof(headerSize));
|
||||
|
||||
std::string header(headerSize, '\0');
|
||||
@@ -110,7 +118,7 @@ void Model::loadWeights(const std::string& path) {
|
||||
for (const auto& tensorInfo : tensorInfos) {
|
||||
std::vector<float> values(tensorInfo.size);
|
||||
|
||||
file.seekg(sizeof(int64_t) + header.size() + tensorInfo.offset);
|
||||
file.seekg(sizeof(version) + sizeof(headerSize) + header.size() + tensorInfo.offset);
|
||||
file.read(reinterpret_cast<char*>(values.data()), tensorInfo.size * sizeof(float));
|
||||
|
||||
if (layerMap.find(tensorInfo.name) != layerMap.end()) {
|
||||
|
||||
Reference in New Issue
Block a user