Implement simple model validation

This commit is contained in:
2024-04-22 20:57:40 +02:00
parent f17debc244
commit c2acad151b
3 changed files with 24 additions and 0 deletions

View File

@@ -109,6 +109,9 @@ int main(int argc, const char* const argv[]) {
const int outputSize = 1000; const int outputSize = 1000;
CUDANet::Model *model = createModel(inputSize, inputChannels, outputSize); CUDANet::Model *model = createModel(inputSize, inputChannels, outputSize);
model->validate();
model->loadWeights(modelWeightsPath); model->loadWeights(modelWeightsPath);
// Read and normalize the image // Read and normalize the image

View File

@@ -36,6 +36,8 @@ class Model {
void loadWeights(const std::string& path); void loadWeights(const std::string& path);
bool validate();
private: private:
Layers::Input* inputLayer; Layers::Input* inputLayer;
Layers::Output* outputLayer; Layers::Output* outputLayer;

View File

@@ -147,4 +147,23 @@ void Model::loadWeights(const std::string& path) {
} }
file.close(); file.close();
}
bool Model::validate() {
bool valid = true;
int size = inputLayer->getInputSize();
for (const auto& layer : layers) {
if (layer.second->getInputSize() != size) {
valid = false;
std::cerr << "Layer: " << layer.first << " has incorrect input size, expected " << size << " but got "
<< layer.second->getInputSize() << std::endl;
break;
}
size = layer.second->getOutputSize();
}
return valid;
} }