diff --git a/examples/inception_v3/CMakeLists.txt b/examples/inception_v3/CMakeLists.txt index 692f7f6..26c6935 100644 --- a/examples/inception_v3/CMakeLists.txt +++ b/examples/inception_v3/CMakeLists.txt @@ -6,6 +6,7 @@ project(Inceptionv3 add_library(inception_v3_lib inception_modules.cpp + inception_utils.cpp ) find_library(CUDANet_LIBRARY NAMES CUDANet HINTS ${CMAKE_CURRENT_SOURCE_DIR}/../../build) diff --git a/examples/inception_v3/inception_utils.cpp b/examples/inception_v3/inception_utils.cpp new file mode 100644 index 0000000..b3fe16f --- /dev/null +++ b/examples/inception_v3/inception_utils.cpp @@ -0,0 +1,46 @@ +#include +#include +#include + +std::vector +readAndNormalizeImage(const std::string &imagePath, int resizeSize, int cropSize) { + // Read the image using OpenCV + cv::Mat image = cv::imread(imagePath, cv::IMREAD_COLOR); + // Convert the image from BGR to RGB + cv::cvtColor(image, image, cv::COLOR_BGR2RGB); + + // Calculate the scaling factor + double scale = std::max(static_cast(resizeSize) / image.cols, static_cast(resizeSize) / image.rows); + + // Resize the image + cv::Mat resized; + cv::resize(image, resized, cv::Size(), scale, scale, cv::INTER_AREA); + + // Calculate the cropping coordinates + int x = (resized.cols - cropSize) / 2; + int y = (resized.rows - cropSize) / 2; + + // Perform center cropping + cv::Rect roi(x, y, cropSize, cropSize); + image = resized(roi); + + // Normalize the image + image.convertTo(image, CV_32FC3, 1.0 / 255.0); + + cv::Mat mean(image.size(), CV_32FC3, cv::Scalar(0.485, 0.456, 0.406)); + cv::Mat std(image.size(), CV_32FC3, cv::Scalar(0.229, 0.224, 0.225)); + cv::subtract(image, mean, image); + cv::divide(image, std, image); + + // Convert the 3D image matrix to a 1D array of floats + std::vector imageData; + for (int c = 0; c < image.channels(); ++c) { + for (int i = 0; i < image.rows; ++i) { + for (int j = 0; j < image.cols; ++j) { + imageData.push_back(image.at(i, j)[c]); + } + } + } + + return imageData; +} \ No newline at end of file diff --git a/examples/inception_v3/inception_v3.cpp b/examples/inception_v3/inception_v3.cpp index 7617f98..5556861 100644 --- a/examples/inception_v3/inception_v3.cpp +++ b/examples/inception_v3/inception_v3.cpp @@ -3,34 +3,8 @@ #include #include #include +#include -std::vector -readAndNormalizeImage(const std::string &imagePath, int width, int height) { - // Read the image using OpenCV - cv::Mat image = cv::imread(imagePath, cv::IMREAD_COLOR); - - // Resize and normalize the image - cv::resize(image, image, cv::Size(width, height)); - image.convertTo(image, CV_32FC3, 1.0 / 255.0); - - // Normalize the image https://pytorch.org/hub/pytorch_vision_alexnet/ - cv::Mat mean(image.size(), CV_32FC3, cv::Scalar(0.485, 0.456, 0.406)); - cv::Mat std(image.size(), CV_32FC3, cv::Scalar(0.229, 0.224, 0.225)); - cv::subtract(image, mean, image); - cv::divide(image, std, image); - - // Convert the 3D image matrix to a 1D array of floats - std::vector imageData; - for (int c = 0; c < image.channels(); ++c) { - for (int i = 0; i < image.rows; ++i) { - for (int j = 0; j < image.cols; ++j) { - imageData.push_back(image.at(i, j)[c]); - } - } - } - - return imageData; -} int main(int argc, const char *const argv[]) { if (argc != 3) { @@ -55,12 +29,12 @@ int main(int argc, const char *const argv[]) { inception_v3->loadWeights(modelWeightsPath); std::vector imageData = - readAndNormalizeImage(imagePath, inputSize.first, inputSize.second); + readAndNormalizeImage(imagePath, inputSize.first, inputSize.first); // Print the size of the image data const float *output = inception_v3->predict(imageData.data()); - // Get max index + // Get max index int maxIndex = 0; for (int i = 0; i < outputSize; i++) { if (output[i] > output[maxIndex]) { diff --git a/examples/inception_v3/inception_v3.hpp b/examples/inception_v3/inception_v3.hpp index 768fb15..dae1fc8 100644 --- a/examples/inception_v3/inception_v3.hpp +++ b/examples/inception_v3/inception_v3.hpp @@ -2,6 +2,10 @@ #define INCEPTION_V3_HPP #include +#include + +std::vector +readAndNormalizeImage(const std::string &imagePath, int resizeSize, int cropSize); class BasicConv2d : public CUDANet::Module { public: