mirror of https://github.com/lordmathis/CUDANet.git, synced 2025-11-05 17:34:21 +00:00
Add version to model bin format
README.md
@@ -4,20 +4,7 @@
 Convolutional Neural Network inference library running on CUDA.
 
-## Features
+## Quickstart Guide
 
-- [x] Input layer
-- [x] Dense (fully-connected) layer
-- [x] Conv2d layer
-- [x] Max pooling
-- [x] Average pooling
-- [x] Concat layer
-- [x] Sigmoid activation
-- [x] ReLU activation
-- [x] Softmax activation
-- [x] Load weights from file
-
-## Usage
-
 **requirements**
 
 - [cmake](https://cmake.org/)
@@ -94,7 +81,7 @@ float* MyModel::predict(const float* input) {
 CUDANet uses a format similar to safetensors to load weights and biases.
 
 ```
-[int64 header size, header, tensor values]
+[u_short version, u_int64 header size, header, tensor values]
 ```
 
 where `header` is in csv format
@@ -103,6 +90,4 @@ where `header` is in csv format
 <tensor_name>,<tensor_size>,<tensor_offset>
 ```
 
-To load weights call `load_weights` function on Model object.
-
-To export weights from pytorch you can use the `export_model_weights` function from `tools/utils.py` script
+To load weights, call the `load_weights` function on the Model object. To export weights from PyTorch, use the `export_model_weights` function from the `tools/utils.py` script. Currently only float32 weights are supported.
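For reference, the versioned layout can be parsed with a few lines of Python. This is a minimal reader sketch, not part of the library: the file name is a placeholder, the `'H'`/`'Q'` struct formats mirror the `struct.pack` calls in `tools/utils.py`, and one `name,size,offset` record per header line plus byte-valued offsets are assumptions based on the C++ loader below.

```
# Sketch of a reader for the versioned format; assumes float32 tensors,
# native byte order (matching struct.pack('H', ...) / ('Q', ...) in
# tools/utils.py), one "name,size,offset" record per header line, and
# offsets counted in bytes from the start of the tensor data.
import struct
import numpy as np

def read_model_weights(filename="model.bin"):  # placeholder file name
    tensors = {}
    with open(filename, "rb") as f:
        version, = struct.unpack("H", f.read(2))       # u_short version
        if version != 1:
            raise ValueError(f"Unsupported model version: {version}")
        header_size, = struct.unpack("Q", f.read(8))   # u_int64 header size
        header = f.read(header_size).decode("utf-8")
        data_start = 2 + 8 + header_size               # tensor values start here

        for record in header.strip().splitlines():
            name, size, offset = record.split(",")
            f.seek(data_start + int(offset))
            raw = f.read(int(size) * 4)                # float32 only for now
            tensors[name] = np.frombuffer(raw, dtype=np.float32)
    return tensors
```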
@@ -66,7 +66,15 @@ void Model::loadWeights(const std::string& path) {
         return;
     }
 
-    int64_t headerSize;
+    u_short version;
+    file.read(reinterpret_cast<char*>(&version), sizeof(version));
+
+    if (version != 1) {
+        std::cerr << "Unsupported model version: " << version << std::endl;
+        return;
+    }
+
+    u_int64_t headerSize;
     file.read(reinterpret_cast<char*>(&headerSize), sizeof(headerSize));
 
     std::string header(headerSize, '\0');
@@ -110,7 +118,7 @@ void Model::loadWeights(const std::string& path) {
     for (const auto& tensorInfo : tensorInfos) {
         std::vector<float> values(tensorInfo.size);
 
-        file.seekg(sizeof(int64_t) + header.size() + tensorInfo.offset);
+        file.seekg(sizeof(version) + sizeof(headerSize) + header.size() + tensorInfo.offset);
         file.read(reinterpret_cast<char*>(values.data()), tensorInfo.size * sizeof(float));
 
         if (layerMap.find(tensorInfo.name) != layerMap.end()) {
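The updated `seekg` call reflects the new prefix: tensor offsets in the header are relative to the start of the tensor data, which now begins after the version and the header size (assuming the usual 2-byte `u_short` and 8-byte `u_int64_t`) rather than after a single `int64_t`. A quick worked example of the arithmetic, with made-up header length and tensor offset:

```
# Absolute byte position of a tensor under the new layout (illustrative numbers).
version_bytes     = 2    # sizeof(u_short version)
header_size_bytes = 8    # sizeof(u_int64_t headerSize)
header_bytes      = 120  # example header.size(); made-up value
tensor_offset     = 256  # example tensorInfo.offset from the header, in bytes

new_position = version_bytes + header_size_bytes + header_bytes + tensor_offset  # 386
old_position = 8 + header_bytes + tensor_offset  # old layout: sizeof(int64_t) + ...
```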
Binary file not shown.
tools/utils.py
@@ -16,6 +16,7 @@ def print_cpp_vector(vector, name="expected"):
 def export_model_weights(model: torch.nn.Module, filename):
     with open(filename, 'wb') as f:
 
+        version = 1
         header = ""
         offset = 0
         tensor_data = b""
@@ -33,7 +34,8 @@ def export_model_weights(model: torch.nn.Module, filename):
             tensor_data += tensor_bytes
 
         f.seek(0)
-        f.write(struct.pack('q', len(header)))
+        f.write(struct.pack('H', version))
+        f.write(struct.pack('Q', len(header)))
         f.write(header.encode('utf-8'))
         f.write(tensor_data)
 
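A minimal usage sketch of the exporter after this change; the toy model and output file name are placeholders, and the import assumes `tools/utils.py` is importable as a module:

```
import torch
from utils import export_model_weights  # assumes tools/ is on sys.path

# Toy model purely for illustration; any torch.nn.Module with float32
# parameters should export the same way.
model = torch.nn.Sequential(
    torch.nn.Linear(4, 8),
    torch.nn.ReLU(),
    torch.nn.Linear(8, 2),
)

export_model_weights(model, "model.bin")  # writes the version-1 binary format
```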