From c2acad151b97f33cb4216a222d6bac4abf3f8ff6 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Mon, 22 Apr 2024 20:57:40 +0200 Subject: [PATCH] Implement simple model validation --- examples/alexnet/main.cpp | 3 +++ include/model/model.hpp | 2 ++ src/model/model.cpp | 19 +++++++++++++++++++ 3 files changed, 24 insertions(+) diff --git a/examples/alexnet/main.cpp b/examples/alexnet/main.cpp index 1e88b9d..e133f65 100644 --- a/examples/alexnet/main.cpp +++ b/examples/alexnet/main.cpp @@ -109,6 +109,9 @@ int main(int argc, const char* const argv[]) { const int outputSize = 1000; CUDANet::Model *model = createModel(inputSize, inputChannels, outputSize); + + model->validate(); + model->loadWeights(modelWeightsPath); // Read and normalize the image diff --git a/include/model/model.hpp b/include/model/model.hpp index 126c7a5..cdb1158 100644 --- a/include/model/model.hpp +++ b/include/model/model.hpp @@ -36,6 +36,8 @@ class Model { void loadWeights(const std::string& path); + bool validate(); + private: Layers::Input* inputLayer; Layers::Output* outputLayer; diff --git a/src/model/model.cpp b/src/model/model.cpp index 8b138c9..a8530f6 100644 --- a/src/model/model.cpp +++ b/src/model/model.cpp @@ -147,4 +147,23 @@ void Model::loadWeights(const std::string& path) { } file.close(); +} + +bool Model::validate() { + + bool valid = true; + int size = inputLayer->getInputSize(); + + for (const auto& layer : layers) { + if (layer.second->getInputSize() != size) { + valid = false; + std::cerr << "Layer: " << layer.first << " has incorrect input size, expected " << size << " but got " + << layer.second->getInputSize() << std::endl; + break; + } + + size = layer.second->getOutputSize(); + } + + return valid; } \ No newline at end of file