mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-06 01:34:22 +00:00
Add version to model bin format
This commit is contained in:
@@ -16,6 +16,7 @@ def print_cpp_vector(vector, name="expected"):
|
||||
def export_model_weights(model: torch.nn.Module, filename):
|
||||
with open(filename, 'wb') as f:
|
||||
|
||||
version = 1
|
||||
header = ""
|
||||
offset = 0
|
||||
tensor_data = b""
|
||||
@@ -33,7 +34,8 @@ def export_model_weights(model: torch.nn.Module, filename):
|
||||
tensor_data += tensor_bytes
|
||||
|
||||
f.seek(0)
|
||||
f.write(struct.pack('q', len(header)))
|
||||
f.write(struct.pack('H', version))
|
||||
f.write(struct.pack('Q', len(header)))
|
||||
f.write(header.encode('utf-8'))
|
||||
f.write(tensor_data)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user