Start implementing weights import

2024-04-15 22:17:48 +02:00
parent d8c50116e8
commit f4ae45f867
3 changed files with 159 additions and 31 deletions

@@ -1,5 +1,11 @@
#include "model.hpp"
#include <fstream>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>
#include "input.cuh"
#include "layer.cuh"
@@ -25,7 +31,7 @@ Model::Model(const Model& other)
    outputLayer = new Layers::Output(*other.outputLayer);
}
Model::~Model(){
Model::~Model() {
    delete inputLayer;
    delete outputLayer;
    for (auto layer : layers) {
@@ -51,4 +57,71 @@ void Model::addLayer(const std::string& name, Layers::SequentialLayer* layer) {
    if (wLayer != nullptr) {
        layerMap[name] = wLayer;
    }
}

void Model::loadWeights(const std::string& path) {
    std::ifstream file(path, std::ios::binary);
    if (!file.is_open()) {
        std::cerr << "Failed to open file: " << path << std::endl;
        return;
    }

    // The file starts with an 8-byte header length, followed by the header text
    int64_t headerSize;
    file.read(reinterpret_cast<char*>(&headerSize), sizeof(headerSize));

    std::string header(headerSize, '\0');
    file.read(&header[0], headerSize);

    // Each header line describes one tensor: "<layer>.<w|b>,<size>,<offset>"
    std::vector<TensorInfo> tensorInfos;
    size_t pos = 0;
    while (pos < header.size()) {
        size_t nextPos = header.find('\n', pos);
        if (nextPos == std::string::npos)
            break;

        std::string line = header.substr(pos, nextPos - pos);
        pos = nextPos + 1;

        size_t commaPos = line.find(',');
        if (commaPos == std::string::npos)
            continue;

        // Split the tensor name into layer name and tensor type
        std::string nameStr = line.substr(0, commaPos);
        size_t dotPos = nameStr.find('.');
        if (dotPos == std::string::npos)
            continue;

        std::string name = nameStr.substr(0, dotPos);
        TensorType type = nameStr.substr(dotPos + 1) == "w" ? TensorType::WEIGHT : TensorType::BIAS;

        // Parse the element count and the byte offset of the tensor data
        line = line.substr(commaPos + 1);
        commaPos = line.find(',');
        if (commaPos == std::string::npos)
            continue;

        int size = std::stoi(line.substr(0, commaPos));
        int offset = std::stoi(line.substr(commaPos + 1));

        tensorInfos.push_back({name, type, size, offset});
    }

    // Read each tensor's data and hand it to the matching registered layer
    for (const auto& tensorInfo : tensorInfos) {
        std::vector<float> values(tensorInfo.size);
        file.seekg(tensorInfo.offset);
        file.read(reinterpret_cast<char*>(values.data()), tensorInfo.size * sizeof(float));

        if (layerMap.find(tensorInfo.name) != layerMap.end()) {
            if (tensorInfo.type == TensorType::WEIGHT) {
                layerMap[tensorInfo.name]->setWeights(values.data());
            } else if (tensorInfo.type == TensorType::BIAS) {
                layerMap[tensorInfo.name]->setBiases(values.data());
            }
        }
    }

    file.close();
}
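
For reference, the binary layout that loadWeights expects can be inferred from the parsing code above: an 8-byte header length, a text header with one line per tensor of the form "<layer>.<w|b>,<elementCount>,<absoluteByteOffset>", then the raw float32 data. The exporter below is only a sketch of that inferred format; writeWeightsFile and TensorRecord are hypothetical names, not part of this commit. It pads the header to a fixed size, which the loader tolerates because header lines without a comma are skipped.

// Hypothetical exporter sketch (not part of this commit) for the format
// that Model::loadWeights reads: [int64 header length][header text][float32 data].
#include <cstdint>
#include <fstream>
#include <string>
#include <vector>

struct TensorRecord {
    std::string layerName;      // must match the name registered via Model::addLayer
    char kind;                  // 'w' = weights, 'b' = biases
    std::vector<float> values;  // tensor data, flattened
};

bool writeWeightsFile(const std::string& path, const std::vector<TensorRecord>& tensors) {
    const int64_t headerSize = 4096;                // fixed-size, '\n'-padded header block
    int64_t offset = sizeof(int64_t) + headerSize;  // tensor data starts right after the header

    // One header line per tensor: "<layer>.<w|b>,<elementCount>,<absoluteByteOffset>"
    std::string header;
    for (const auto& t : tensors) {
        header += t.layerName + "." + t.kind + "," +
                  std::to_string(t.values.size()) + "," +
                  std::to_string(offset) + "\n";
        offset += static_cast<int64_t>(t.values.size()) * sizeof(float);
    }
    if (static_cast<int64_t>(header.size()) > headerSize)
        return false;                               // header block too small for this model
    header.resize(headerSize, '\n');                // padding lines have no comma, so the loader skips them

    std::ofstream file(path, std::ios::binary);
    if (!file.is_open())
        return false;
    file.write(reinterpret_cast<const char*>(&headerSize), sizeof(headerSize));
    file.write(header.data(), static_cast<std::streamsize>(header.size()));
    for (const auto& t : tensors) {
        file.write(reinterpret_cast<const char*>(t.values.data()),
                   static_cast<std::streamsize>(t.values.size() * sizeof(float)));
    }
    return file.good();
}

With such a file in place, a model whose layers were registered through addLayer under the same names would pick up its parameters via model.loadWeights("weights.bin").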