diff --git a/examples/inception_v3/inception_v3.cpp b/examples/inception_v3/inception_v3.cpp new file mode 100644 index 0000000..9c8e0db --- /dev/null +++ b/examples/inception_v3/inception_v3.cpp @@ -0,0 +1,48 @@ +#include +#include + +int main(int argc, const char *const argv[]) { + if (argc != 3) { + std::cerr << "Usage: " << argv[0] << " " + << std::endl; + return 1; // Return error code indicating incorrect usage + } + + std::cout << "Loading model..." << std::endl; +} + +class BasicConv2d : public CUDANet::Module { + public: + BasicConv2d( + const int inputSize, + const int inputChannels, + const int outputChannels, + const int kernelSize, + const int stride, + const int padding, + const std::string& prefix + ) + : inputSize(inputSize), + inputChannels(inputChannels), + outputChannels(outputChannels) { + // Create the convolution layer + CUDANet::Layers::Conv2d *conv = new CUDANet::Layers::Conv2d( + inputSize, inputChannels, kernelSize, stride, outputChannels, padding, CUDANet::Layers::ActivationType::NONE + ); + + int batchNormSize = conv->getOutputSize(); + + CUDANet::Layers::BatchNorm *batchNorm = + new CUDANet::Layers::BatchNorm( + batchNormSize, outputChannels, 1e-3f, CUDANet::Layers::ActivationType::RELU + ); + + addLayer(prefix + ".conv", conv); + addLayer(prefix + ".bn", batchNorm); + } + + private: + int inputSize; + int inputChannels; + int outputChannels; +}; diff --git a/examples/inception_v3/inception_v3.py b/examples/inception_v3/inception_v3.py new file mode 100644 index 0000000..0570af7 --- /dev/null +++ b/examples/inception_v3/inception_v3.py @@ -0,0 +1,17 @@ +import torch +import torchvision +import sys + +from torchsummary import summary + +inception = torchvision.models.inception_v3(weights=torchvision.models.Inception_V3_Weights.DEFAULT) +inception.eval() + +sys.path.append('../../tools') # Ugly hack +from utils import export_model_weights, print_model_parameters + +print_model_parameters(inception) # print layer names and number of parameters + +inception.cuda() + +summary(inception, (3, 299, 299)) \ No newline at end of file