mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-06 01:34:22 +00:00
Implement pytorch weights export
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import torch
|
||||
import utils
|
||||
|
||||
from export_model_weights import export_model_weights
|
||||
|
||||
class TestModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
@@ -121,3 +123,5 @@ if __name__ == "__main__":
|
||||
]).reshape(2, 6, 6)
|
||||
out = model(input)
|
||||
utils.print_cpp_vector(out)
|
||||
|
||||
export_model_weights(model, "test/resources/model.bin")
|
||||
|
||||
Reference in New Issue
Block a user