diff --git a/examples/inception_v3/inception_v3.cpp b/examples/inception_v3/inception_v3.cpp
index 9c8e0db..5091df7 100644
--- a/examples/inception_v3/inception_v3.cpp
+++ b/examples/inception_v3/inception_v3.cpp
@@ -14,35 +14,91 @@ int main(int argc, const char *const argv[]) {
 class BasicConv2d : public CUDANet::Module {
   public:
     BasicConv2d(
-        const int inputSize,
-        const int inputChannels,
-        const int outputChannels,
-        const int kernelSize,
-        const int stride,
-        const int padding,
-        const std::string& prefix
+        const int          inputSize,
+        const int          inputChannels,
+        const int          outputChannels,
+        const int          kernelSize,
+        const int          stride,
+        const int          padding,
+        const std::string &prefix
     )
         : inputSize(inputSize),
           inputChannels(inputChannels),
           outputChannels(outputChannels) {
         // Create the convolution layer
         CUDANet::Layers::Conv2d *conv = new CUDANet::Layers::Conv2d(
-            inputSize, inputChannels, kernelSize, stride, outputChannels, padding, CUDANet::Layers::ActivationType::NONE
+            inputSize, inputChannels, kernelSize, stride, outputChannels,
+            padding, CUDANet::Layers::ActivationType::NONE
         );
 
         int batchNormSize = conv->getOutputSize();
 
-        CUDANet::Layers::BatchNorm *batchNorm =
-            new CUDANet::Layers::BatchNorm(
-                batchNormSize, outputChannels, 1e-3f, CUDANet::Layers::ActivationType::RELU
-            );
+        CUDANet::Layers::BatchNorm *batchNorm = new CUDANet::Layers::BatchNorm(
+            batchNormSize, outputChannels, 1e-3f,
+            CUDANet::Layers::ActivationType::RELU
+        );
 
         addLayer(prefix + ".conv", conv);
         addLayer(prefix + ".bn", batchNorm);
     }
 
+    // Chain the registered layers in insertion order: conv, then batchnorm + ReLU
+    float* forward(const float* d_input) {
+        float* d_output = nullptr;
+        for (auto& layer : layers) {
+            d_output = layer.second->forward(d_input);
+            d_input  = d_output;
+        }
+        return d_output;
+    }
+
   private:
     int inputSize;
    int inputChannels;
     int outputChannels;
 };
+
+class InceptionA : public CUDANet::Module {
+  public:
+    InceptionA(
+        const int          inputSize,
+        const int          inputChannels,
+        const int          poolFeatures,
+        const std::string &prefix
+    )
+        : inputSize(inputSize),
+          inputChannels(inputChannels),
+          poolFeatures(poolFeatures) {
+
+        // Branch 1x1
+        CUDANet::Module *branch1x1 = new BasicConv2d(
+            inputSize, inputChannels, 64, 1, 1, 0, prefix + ".branch1x1"
+        );
+
+        // Branch 5x5
+        CUDANet::Module *branch5x5_1 = new BasicConv2d(
+            inputSize, inputChannels, 48, 1, 1, 0, prefix + ".branch5x5_1"
+        );
+        CUDANet::Module *branch5x5_2 = new BasicConv2d(
+            inputSize, 48, 64, 5, 1, 2, prefix + ".branch5x5_2"
+        );
+
+        // Branch 3x3
+        CUDANet::Module *branch3x3_1 = new BasicConv2d(
+            inputSize, inputChannels, 64, 1, 1, 0, prefix + ".branch3x3_1"
+        );
+        CUDANet::Module *branch3x3_2 = new BasicConv2d(
+            inputSize, 64, 96, 3, 1, 1, prefix + ".branch3x3_2"
+        );
+        CUDANet::Module *branch3x3_3 = new BasicConv2d(
+            inputSize, 96, 96, 3, 1, 1, prefix + ".branch3x3_3"
+        );
+
+        // Branch Pool
+        CUDANet::Module *branchPool = new BasicConv2d(
+            inputSize, inputChannels, poolFeatures, 1, 1, 0, prefix + ".branchPool"
+        );
+    }
+
+  private:
+    int inputSize;
+    int inputChannels;
+    int poolFeatures;
+};
\ No newline at end of file
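
Usage sketch (illustrative only, not part of the diff): assuming Module::forward takes and returns device pointers, as the d_input naming above suggests, a single BasicConv2d block could be exercised as below from inside a function such as main(), with <vector> and <cuda_runtime.h> included. The 299x299x3 input shape, the "stem" prefix string, and the raw CUDA buffer handling are assumptions made for this example only.

    // Hypothetical driver: build the first Inception v3 convolution block
    // (3 -> 32 channels, 3x3 kernel, stride 2, no padding) and run it once.
    BasicConv2d stem(299, 3, 32, 3, 2, 0, "stem");

    // Placeholder host image buffer, copied to a device buffer.
    std::vector<float> h_input(299 * 299 * 3, 0.0f);
    float *d_input = nullptr;
    cudaMalloc(&d_input, h_input.size() * sizeof(float));
    cudaMemcpy(d_input, h_input.data(), h_input.size() * sizeof(float),
               cudaMemcpyHostToDevice);

    float *d_output = stem.forward(d_input);  // conv, then batchnorm + ReLU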