Add version to model bin format

This commit is contained in:
2024-04-23 19:53:48 +02:00
parent 1592f06121
commit 69111f6cb1
4 changed files with 16 additions and 21 deletions

View File

@@ -16,6 +16,7 @@ def print_cpp_vector(vector, name="expected"):
def export_model_weights(model: torch.nn.Module, filename):
with open(filename, 'wb') as f:
version = 1
header = ""
offset = 0
tensor_data = b""
@@ -33,7 +34,8 @@ def export_model_weights(model: torch.nn.Module, filename):
tensor_data += tensor_bytes
f.seek(0)
f.write(struct.pack('q', len(header)))
f.write(struct.pack('H', version))
f.write(struct.pack('Q', len(header)))
f.write(header.encode('utf-8'))
f.write(tensor_data)