mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-05 17:34:21 +00:00
Implement pytorch weights export
This commit is contained in:
BIN
test/resources/model.bin
Normal file
BIN
test/resources/model.bin
Normal file
Binary file not shown.
28
tools/export_model_weights.py
Normal file
28
tools/export_model_weights.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
import torch
|
||||||
|
import struct
|
||||||
|
|
||||||
|
def export_model_weights(model: torch.nn.Module, filename):
|
||||||
|
with open(filename, 'wb') as f:
|
||||||
|
|
||||||
|
header = ""
|
||||||
|
offset = 0
|
||||||
|
|
||||||
|
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if 'weight' not in name and 'bias' not in name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
tensor_values = param.flatten().tolist()
|
||||||
|
tensor_bytes = struct.pack('f' * len(tensor_values), *tensor_values)
|
||||||
|
|
||||||
|
tensor_size = param.numel()
|
||||||
|
|
||||||
|
header += f"{name},{tensor_size},{offset}\n"
|
||||||
|
|
||||||
|
offset += len(tensor_bytes)
|
||||||
|
|
||||||
|
f.write(tensor_bytes)
|
||||||
|
|
||||||
|
f.seek(0)
|
||||||
|
f.write(struct.pack('q', len(header)))
|
||||||
|
f.write(header.encode('utf-8'))
|
||||||
@@ -1,6 +1,8 @@
|
|||||||
import torch
|
import torch
|
||||||
import utils
|
import utils
|
||||||
|
|
||||||
|
from export_model_weights import export_model_weights
|
||||||
|
|
||||||
class TestModel(torch.nn.Module):
|
class TestModel(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs) -> None:
|
def __init__(self, *args, **kwargs) -> None:
|
||||||
@@ -121,3 +123,5 @@ if __name__ == "__main__":
|
|||||||
]).reshape(2, 6, 6)
|
]).reshape(2, 6, 6)
|
||||||
out = model(input)
|
out = model(input)
|
||||||
utils.print_cpp_vector(out)
|
utils.print_cpp_vector(out)
|
||||||
|
|
||||||
|
export_model_weights(model, "test/resources/model.bin")
|
||||||
|
|||||||
Reference in New Issue
Block a user