mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-06 17:54:27 +00:00
Implement simple model validation
This commit is contained in:
@@ -109,6 +109,9 @@ int main(int argc, const char* const argv[]) {
|
|||||||
const int outputSize = 1000;
|
const int outputSize = 1000;
|
||||||
|
|
||||||
CUDANet::Model *model = createModel(inputSize, inputChannels, outputSize);
|
CUDANet::Model *model = createModel(inputSize, inputChannels, outputSize);
|
||||||
|
|
||||||
|
model->validate();
|
||||||
|
|
||||||
model->loadWeights(modelWeightsPath);
|
model->loadWeights(modelWeightsPath);
|
||||||
|
|
||||||
// Read and normalize the image
|
// Read and normalize the image
|
||||||
|
|||||||
@@ -36,6 +36,8 @@ class Model {
|
|||||||
|
|
||||||
void loadWeights(const std::string& path);
|
void loadWeights(const std::string& path);
|
||||||
|
|
||||||
|
bool validate();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Layers::Input* inputLayer;
|
Layers::Input* inputLayer;
|
||||||
Layers::Output* outputLayer;
|
Layers::Output* outputLayer;
|
||||||
|
|||||||
@@ -148,3 +148,22 @@ void Model::loadWeights(const std::string& path) {
|
|||||||
|
|
||||||
file.close();
|
file.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool Model::validate() {
|
||||||
|
|
||||||
|
bool valid = true;
|
||||||
|
int size = inputLayer->getInputSize();
|
||||||
|
|
||||||
|
for (const auto& layer : layers) {
|
||||||
|
if (layer.second->getInputSize() != size) {
|
||||||
|
valid = false;
|
||||||
|
std::cerr << "Layer: " << layer.first << " has incorrect input size, expected " << size << " but got "
|
||||||
|
<< layer.second->getInputSize() << std::endl;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
size = layer.second->getOutputSize();
|
||||||
|
}
|
||||||
|
|
||||||
|
return valid;
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user