Implement pytorch weights export

This commit is contained in:
2024-04-15 22:17:14 +02:00
parent 13b455e4ba
commit d8c50116e8
3 changed files with 32 additions and 0 deletions

View File

@@ -1,6 +1,8 @@
import torch
import utils
from export_model_weights import export_model_weights
class TestModel(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
@@ -121,3 +123,5 @@ if __name__ == "__main__":
]).reshape(2, 6, 6)
out = model(input)
utils.print_cpp_vector(out)
export_model_weights(model, "test/resources/model.bin")