mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-06 01:34:22 +00:00
Fix bin file seek offset
This commit is contained in:
@@ -110,7 +110,7 @@ void Model::loadWeights(const std::string& path) {
|
|||||||
for (const auto& tensorInfo : tensorInfos) {
|
for (const auto& tensorInfo : tensorInfos) {
|
||||||
std::vector<float> values(tensorInfo.size);
|
std::vector<float> values(tensorInfo.size);
|
||||||
|
|
||||||
file.seekg(tensorInfo.offset);
|
file.seekg(sizeof(int64_t) + header.size() + tensorInfo.offset);
|
||||||
file.read(reinterpret_cast<char*>(values.data()), tensorInfo.size * sizeof(float));
|
file.read(reinterpret_cast<char*>(values.data()), tensorInfo.size * sizeof(float));
|
||||||
|
|
||||||
if (layerMap.find(tensorInfo.name) != layerMap.end()) {
|
if (layerMap.find(tensorInfo.name) != layerMap.end()) {
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
import utils
|
|
||||||
|
|
||||||
from export_model_weights import export_model_weights
|
from utils import export_model_weights, print_cpp_vector
|
||||||
|
|
||||||
class TestModel(torch.nn.Module):
|
class TestModel(torch.nn.Module):
|
||||||
|
|
||||||
@@ -88,7 +87,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
print("Single test output:")
|
print("Single test output:")
|
||||||
out = model(input)
|
out = model(input)
|
||||||
utils.print_cpp_vector(out)
|
print_cpp_vector(out)
|
||||||
|
|
||||||
print("Multiple predict test output 1:")
|
print("Multiple predict test output 1:")
|
||||||
input = torch.tensor([
|
input = torch.tensor([
|
||||||
@@ -105,7 +104,7 @@ if __name__ == "__main__":
|
|||||||
0.46545, 0.88722
|
0.46545, 0.88722
|
||||||
]).reshape(2, 6, 6)
|
]).reshape(2, 6, 6)
|
||||||
out = model(input)
|
out = model(input)
|
||||||
utils.print_cpp_vector(out)
|
print_cpp_vector(out)
|
||||||
|
|
||||||
print("Multiple predict test output 2:")
|
print("Multiple predict test output 2:")
|
||||||
input = torch.tensor([
|
input = torch.tensor([
|
||||||
@@ -122,6 +121,6 @@ if __name__ == "__main__":
|
|||||||
0.84456, 0.44482
|
0.84456, 0.44482
|
||||||
]).reshape(2, 6, 6)
|
]).reshape(2, 6, 6)
|
||||||
out = model(input)
|
out = model(input)
|
||||||
utils.print_cpp_vector(out)
|
print_cpp_vector(out)
|
||||||
|
|
||||||
export_model_weights(model, "test/resources/model.bin")
|
export_model_weights(model, "test/resources/model.bin")
|
||||||
|
|||||||
Reference in New Issue
Block a user