From d64a28bc9c22a178dd4765e0b51404a0313e9b0b Mon Sep 17 00:00:00 2001 From: LordMathis Date: Sun, 21 Apr 2024 00:05:56 +0200 Subject: [PATCH] Fix model weights export --- src/model/model.cpp | 1 + test/resources/model.bin | Bin 240 -> 287 bytes tools/utils.py | 18 +++++++++--------- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/model/model.cpp b/src/model/model.cpp index c18142a..07c721d 100644 --- a/src/model/model.cpp +++ b/src/model/model.cpp @@ -43,6 +43,7 @@ float* Model::predict(const float* input) { float* d_input = inputLayer->forward(input); for (auto& layer : layers) { + std::cout << layer << std::endl; d_input = layer->forward(d_input); } diff --git a/test/resources/model.bin b/test/resources/model.bin index d796950819b364fe94f74201bdb49df6736ddef7..d58f27a937d3aa1491333387c20f41be4a3cee24 100644 GIT binary patch delta 57 zcmV-90LK6D0iObpFE4I}J3jiG2tU+d%03$}jXnj%uRc+|)joAuEIuIb9X|lZ3qK|{ Pe?RkaML%7?fRQi3kQx@n delta 9 QcmbQw^nr1L{=~S0022ZP0ssI2 diff --git a/tools/utils.py b/tools/utils.py index 73dbb2c..6c4f689 100644 --- a/tools/utils.py +++ b/tools/utils.py @@ -1,6 +1,8 @@ import torch import struct +import numpy as np + def print_cpp_vector(vector, name="expected"): print("std::vector " + name + " = {", end="") @@ -16,27 +18,25 @@ def export_model_weights(model: torch.nn.Module, filename): header = "" offset = 0 - + tensor_data = b"" for name, param in model.named_parameters(): if 'weight' not in name and 'bias' not in name: continue - tensor_values = param.flatten().tolist() - tensor_bytes = struct.pack('f' * len(tensor_values), *tensor_values) - + tensor_bytes = param.type(torch.float32).detach().numpy().tobytes() tensor_size = param.numel() - header += f"{name},{tensor_size},{offset}\n" - + header += f"{name},{tensor_size},{offset}\n" offset += len(tensor_bytes) - f.write(tensor_bytes) + tensor_data += tensor_bytes f.seek(0) - f.write(struct.pack('q', len(header))) + f.write(struct.pack('q', len(header))) f.write(header.encode('utf-8')) + f.write(tensor_data) def print_model_parameters(model: torch.nn.Module): for name, param in model.named_parameters(): - print(name, param.numel()) \ No newline at end of file + print(name, param.numel())