mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-06 09:44:28 +00:00
Start implementing weights import
This commit is contained in:
@@ -1,5 +1,11 @@
|
||||
#include "model.hpp"
|
||||
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "input.cuh"
|
||||
#include "layer.cuh"
|
||||
|
||||
@@ -25,7 +31,7 @@ Model::Model(const Model& other)
|
||||
outputLayer = new Layers::Output(*other.outputLayer);
|
||||
}
|
||||
|
||||
Model::~Model(){
|
||||
Model::~Model() {
|
||||
delete inputLayer;
|
||||
delete outputLayer;
|
||||
for (auto layer : layers) {
|
||||
@@ -51,4 +57,71 @@ void Model::addLayer(const std::string& name, Layers::SequentialLayer* layer) {
|
||||
if (wLayer != nullptr) {
|
||||
layerMap[name] = wLayer;
|
||||
}
|
||||
}
|
||||
|
||||
void Model::loadWeights(const std::string& path) {
|
||||
std::ifstream file(path, std::ios::binary);
|
||||
|
||||
if (!file.is_open()) {
|
||||
std::cerr << "Failed to open file: " << path << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
int64_t headerSize;
|
||||
file.read(reinterpret_cast<char*>(&headerSize), sizeof(headerSize));
|
||||
|
||||
std::string header(headerSize, '\0');
|
||||
file.read(&header[0], headerSize);
|
||||
|
||||
std::vector<TensorInfo> tensorInfos;
|
||||
size_t pos = 0;
|
||||
|
||||
while (pos < header.size()) {
|
||||
size_t nextPos = header.find('\n', pos);
|
||||
if (nextPos == std::string::npos)
|
||||
break;
|
||||
|
||||
std::string line = header.substr(pos, nextPos - pos);
|
||||
pos = nextPos + 1;
|
||||
|
||||
size_t commaPos = line.find(',');
|
||||
if (commaPos == std::string::npos)
|
||||
continue;
|
||||
|
||||
// Parse tensor name into name and type
|
||||
std::string nameStr = line.substr(0, commaPos);
|
||||
size_t dotPos = nameStr.find('.');
|
||||
if (dotPos == std::string::npos)
|
||||
continue;
|
||||
std::string name = nameStr.substr(0, dotPos);
|
||||
TensorType type = nameStr.substr(dotPos + 1) == "w" ? TensorType::WEIGHT : TensorType::BIAS;
|
||||
|
||||
line = line.substr(commaPos + 1);
|
||||
|
||||
commaPos = line.find(',');
|
||||
if (commaPos == std::string::npos)
|
||||
continue;
|
||||
|
||||
int size = std::stoi(line.substr(0, commaPos));
|
||||
int offset = std::stoi(line.substr(commaPos + 1));
|
||||
|
||||
tensorInfos.push_back({name, type, size, offset});
|
||||
}
|
||||
|
||||
for (const auto& tensorInfo : tensorInfos) {
|
||||
std::vector<float> values(tensorInfo.size);
|
||||
|
||||
file.seekg(tensorInfo.offset);
|
||||
file.read(reinterpret_cast<char*>(values.data()), tensorInfo.size * sizeof(float));
|
||||
|
||||
if (layerMap.find(tensorInfo.name) != layerMap.end()) {
|
||||
if (tensorInfo.type == TensorType::WEIGHT) {
|
||||
layerMap[tensorInfo.name]->setWeights(values.data());
|
||||
} else if (tensorInfo.type == TensorType::BIAS) {
|
||||
layerMap[tensorInfo.name]->setBiases(values.data());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
file.close();
|
||||
}
|
||||
Reference in New Issue
Block a user