mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-05 17:34:21 +00:00
Fix bin file seek offset
This commit is contained in:
@@ -110,7 +110,7 @@ void Model::loadWeights(const std::string& path) {
|
||||
for (const auto& tensorInfo : tensorInfos) {
|
||||
std::vector<float> values(tensorInfo.size);
|
||||
|
||||
file.seekg(tensorInfo.offset);
|
||||
file.seekg(sizeof(int64_t) + header.size() + tensorInfo.offset);
|
||||
file.read(reinterpret_cast<char*>(values.data()), tensorInfo.size * sizeof(float));
|
||||
|
||||
if (layerMap.find(tensorInfo.name) != layerMap.end()) {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import torch
|
||||
import utils
|
||||
|
||||
from export_model_weights import export_model_weights
|
||||
from utils import export_model_weights, print_cpp_vector
|
||||
|
||||
class TestModel(torch.nn.Module):
|
||||
|
||||
@@ -88,7 +87,7 @@ if __name__ == "__main__":
|
||||
|
||||
print("Single test output:")
|
||||
out = model(input)
|
||||
utils.print_cpp_vector(out)
|
||||
print_cpp_vector(out)
|
||||
|
||||
print("Multiple predict test output 1:")
|
||||
input = torch.tensor([
|
||||
@@ -105,7 +104,7 @@ if __name__ == "__main__":
|
||||
0.46545, 0.88722
|
||||
]).reshape(2, 6, 6)
|
||||
out = model(input)
|
||||
utils.print_cpp_vector(out)
|
||||
print_cpp_vector(out)
|
||||
|
||||
print("Multiple predict test output 2:")
|
||||
input = torch.tensor([
|
||||
@@ -122,6 +121,6 @@ if __name__ == "__main__":
|
||||
0.84456, 0.44482
|
||||
]).reshape(2, 6, 6)
|
||||
out = model(input)
|
||||
utils.print_cpp_vector(out)
|
||||
print_cpp_vector(out)
|
||||
|
||||
export_model_weights(model, "test/resources/model.bin")
|
||||
|
||||
Reference in New Issue
Block a user