mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-06 01:34:22 +00:00
Fix model weights export
This commit is contained in:
@@ -43,6 +43,7 @@ float* Model::predict(const float* input) {
|
|||||||
float* d_input = inputLayer->forward(input);
|
float* d_input = inputLayer->forward(input);
|
||||||
|
|
||||||
for (auto& layer : layers) {
|
for (auto& layer : layers) {
|
||||||
|
std::cout << layer << std::endl;
|
||||||
d_input = layer->forward(d_input);
|
d_input = layer->forward(d_input);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Binary file not shown.
@@ -1,6 +1,8 @@
|
|||||||
import torch
|
import torch
|
||||||
import struct
|
import struct
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def print_cpp_vector(vector, name="expected"):
|
def print_cpp_vector(vector, name="expected"):
|
||||||
print("std::vector<float> " + name + " = {", end="")
|
print("std::vector<float> " + name + " = {", end="")
|
||||||
@@ -16,26 +18,24 @@ def export_model_weights(model: torch.nn.Module, filename):
|
|||||||
|
|
||||||
header = ""
|
header = ""
|
||||||
offset = 0
|
offset = 0
|
||||||
|
tensor_data = b""
|
||||||
|
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
if 'weight' not in name and 'bias' not in name:
|
if 'weight' not in name and 'bias' not in name:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
tensor_values = param.flatten().tolist()
|
tensor_bytes = param.type(torch.float32).detach().numpy().tobytes()
|
||||||
tensor_bytes = struct.pack('f' * len(tensor_values), *tensor_values)
|
|
||||||
|
|
||||||
tensor_size = param.numel()
|
tensor_size = param.numel()
|
||||||
|
|
||||||
header += f"{name},{tensor_size},{offset}\n"
|
header += f"{name},{tensor_size},{offset}\n"
|
||||||
|
|
||||||
offset += len(tensor_bytes)
|
offset += len(tensor_bytes)
|
||||||
|
|
||||||
f.write(tensor_bytes)
|
tensor_data += tensor_bytes
|
||||||
|
|
||||||
f.seek(0)
|
f.seek(0)
|
||||||
f.write(struct.pack('q', len(header)))
|
f.write(struct.pack('q', len(header)))
|
||||||
f.write(header.encode('utf-8'))
|
f.write(header.encode('utf-8'))
|
||||||
|
f.write(tensor_data)
|
||||||
|
|
||||||
def print_model_parameters(model: torch.nn.Module):
|
def print_model_parameters(model: torch.nn.Module):
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
|
|||||||
Reference in New Issue
Block a user