mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-06 09:44:28 +00:00
Export pretrained alexnet
This commit is contained in:
@@ -1,28 +0,0 @@
|
||||
import torch
|
||||
import struct
|
||||
|
||||
def export_model_weights(model: torch.nn.Module, filename):
|
||||
with open(filename, 'wb') as f:
|
||||
|
||||
header = ""
|
||||
offset = 0
|
||||
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if 'weight' not in name and 'bias' not in name:
|
||||
continue
|
||||
|
||||
tensor_values = param.flatten().tolist()
|
||||
tensor_bytes = struct.pack('f' * len(tensor_values), *tensor_values)
|
||||
|
||||
tensor_size = param.numel()
|
||||
|
||||
header += f"{name},{tensor_size},{offset}\n"
|
||||
|
||||
offset += len(tensor_bytes)
|
||||
|
||||
f.write(tensor_bytes)
|
||||
|
||||
f.seek(0)
|
||||
f.write(struct.pack('q', len(header)))
|
||||
f.write(header.encode('utf-8'))
|
||||
@@ -1,7 +1,42 @@
|
||||
def print_cpp_vector(vector):
|
||||
print("std::vector<float> expected = {", end="")
|
||||
import torch
|
||||
import struct
|
||||
|
||||
|
||||
def print_cpp_vector(vector, name="expected"):
|
||||
print("std::vector<float> " + name + " = {", end="")
|
||||
for i in range(len(vector)):
|
||||
if i != 0:
|
||||
print(", ", end="")
|
||||
print(str(round(vector[i].item(), 5)) + "f", end="")
|
||||
print("};")
|
||||
|
||||
|
||||
def export_model_weights(model: torch.nn.Module, filename):
|
||||
with open(filename, 'wb') as f:
|
||||
|
||||
header = ""
|
||||
offset = 0
|
||||
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if 'weight' not in name and 'bias' not in name:
|
||||
continue
|
||||
|
||||
tensor_values = param.flatten().tolist()
|
||||
tensor_bytes = struct.pack('f' * len(tensor_values), *tensor_values)
|
||||
|
||||
tensor_size = param.numel()
|
||||
|
||||
header += f"{name},{tensor_size},{offset}\n"
|
||||
|
||||
offset += len(tensor_bytes)
|
||||
|
||||
f.write(tensor_bytes)
|
||||
|
||||
f.seek(0)
|
||||
f.write(struct.pack('q', len(header)))
|
||||
f.write(header.encode('utf-8'))
|
||||
|
||||
def print_model_parameters(model: torch.nn.Module):
|
||||
for name, param in model.named_parameters():
|
||||
print(name, param.numel())
|
||||
Reference in New Issue
Block a user