diff --git a/examples/alexnet/alexnet.cpp b/examples/alexnet/alexnet.cpp index e28787e..65effff 100644 --- a/examples/alexnet/alexnet.cpp +++ b/examples/alexnet/alexnet.cpp @@ -4,15 +4,30 @@ #include std::vector -readAndNormalizeImage(const std::string &imagePath, int width, int height) { +readAndNormalizeImage(const std::string &imagePath, int resizeSize, int cropSize) { // Read the image using OpenCV cv::Mat image = cv::imread(imagePath, cv::IMREAD_COLOR); + // Convert the image from BGR to RGB + cv::cvtColor(image, image, cv::COLOR_BGR2RGB); - // Resize and normalize the image - cv::resize(image, image, cv::Size(width, height)); + // Calculate the scaling factor + double scale = std::max(static_cast(resizeSize) / image.cols, static_cast(resizeSize) / image.rows); + + // Resize the image + cv::Mat resized; + cv::resize(image, resized, cv::Size(), scale, scale, cv::INTER_AREA); + + // Calculate the cropping coordinates + int x = (resized.cols - cropSize) / 2; + int y = (resized.rows - cropSize) / 2; + + // Perform center cropping + cv::Rect roi(x, y, cropSize, cropSize); + image = resized(roi); + + // Normalize the image image.convertTo(image, CV_32FC3, 1.0 / 255.0); - // Normalize the image https://pytorch.org/hub/pytorch_vision_alexnet/ cv::Mat mean(image.size(), CV_32FC3, cv::Scalar(0.485, 0.456, 0.406)); cv::Mat std(image.size(), CV_32FC3, cv::Scalar(0.229, 0.224, 0.225)); cv::subtract(image, mean, image); @@ -124,7 +139,7 @@ int main(int argc, const char *const argv[]) { // Read and normalize the image std::vector imageData = - readAndNormalizeImage(imagePath, inputSize.first, inputSize.second); + readAndNormalizeImage(imagePath, inputSize.first, inputSize.first); // Print the size of the image data const float *output = model->predict(imageData.data()); diff --git a/examples/alexnet/alexnet.py b/examples/alexnet/alexnet.py index bcd0a36..73c2a68 100644 --- a/examples/alexnet/alexnet.py +++ b/examples/alexnet/alexnet.py @@ -3,11 +3,17 @@ import sys import torchvision sys.path.append('../../tools') # Ugly hack -from utils import export_model_weights, print_model_parameters +from utils import export_model_weights, print_model_parameters, predict if __name__ == "__main__": - alexnet = torchvision.models.alexnet(weights=torchvision.models.AlexNet_Weights.DEFAULT) - print_model_parameters(alexnet) # print layer names and number of parameters - export_model_weights(alexnet, 'alexnet_weights.bin') - # predict(alexnet, 'cat.jpg') + + weights = torchvision.models.AlexNet_Weights.DEFAULT + alexnet = torchvision.models.alexnet(weights=weights) + + # print_model_parameters(alexnet) # print layer names and number of parameters + export_model_weights(alexnet, 'alexnet_weights.bin') + + # class_labels = weights.meta["categories"] + # prediction = predict(alexnet, "margot.jpg") + # print(prediction, class_labels[prediction])