Fix bin file seek offset

This commit is contained in:
2024-04-20 21:30:01 +02:00
parent 5e663b9029
commit 9c5d853b75
2 changed files with 5 additions and 6 deletions

View File

@@ -110,7 +110,7 @@ void Model::loadWeights(const std::string& path) {
for (const auto& tensorInfo : tensorInfos) { for (const auto& tensorInfo : tensorInfos) {
std::vector<float> values(tensorInfo.size); std::vector<float> values(tensorInfo.size);
file.seekg(tensorInfo.offset); file.seekg(sizeof(int64_t) + header.size() + tensorInfo.offset);
file.read(reinterpret_cast<char*>(values.data()), tensorInfo.size * sizeof(float)); file.read(reinterpret_cast<char*>(values.data()), tensorInfo.size * sizeof(float));
if (layerMap.find(tensorInfo.name) != layerMap.end()) { if (layerMap.find(tensorInfo.name) != layerMap.end()) {

View File

@@ -1,7 +1,6 @@
import torch import torch
import utils
from export_model_weights import export_model_weights from utils import export_model_weights, print_cpp_vector
class TestModel(torch.nn.Module): class TestModel(torch.nn.Module):
@@ -88,7 +87,7 @@ if __name__ == "__main__":
print("Single test output:") print("Single test output:")
out = model(input) out = model(input)
utils.print_cpp_vector(out) print_cpp_vector(out)
print("Multiple predict test output 1:") print("Multiple predict test output 1:")
input = torch.tensor([ input = torch.tensor([
@@ -105,7 +104,7 @@ if __name__ == "__main__":
0.46545, 0.88722 0.46545, 0.88722
]).reshape(2, 6, 6) ]).reshape(2, 6, 6)
out = model(input) out = model(input)
utils.print_cpp_vector(out) print_cpp_vector(out)
print("Multiple predict test output 2:") print("Multiple predict test output 2:")
input = torch.tensor([ input = torch.tensor([
@@ -122,6 +121,6 @@ if __name__ == "__main__":
0.84456, 0.44482 0.84456, 0.44482
]).reshape(2, 6, 6) ]).reshape(2, 6, 6)
out = model(input) out = model(input)
utils.print_cpp_vector(out) print_cpp_vector(out)
export_model_weights(model, "test/resources/model.bin") export_model_weights(model, "test/resources/model.bin")