Format inception

This commit is contained in:
2024-05-26 20:06:11 +02:00
parent ca9c4619e9
commit 07d505a0e5

View File

@@ -276,16 +276,16 @@ class InceptionC : public CUDANet::Module {
"branch7x7dbl_1" "branch7x7dbl_1"
); );
branch7x7dbl_2 = new BasicConv2d( branch7x7dbl_2 = new BasicConv2d(
branch7x7dbl_1->getOutputDims(), nChannels_7x7, nChannels_7x7, {7, 1}, branch7x7dbl_1->getOutputDims(), nChannels_7x7, nChannels_7x7,
{1, 1}, {3, 0}, "branch7x7dbl_2" {7, 1}, {1, 1}, {3, 0}, "branch7x7dbl_2"
); );
branch7x7dbl_3 = new BasicConv2d( branch7x7dbl_3 = new BasicConv2d(
branch7x7dbl_2->getOutputDims(), nChannels_7x7, nChannels_7x7, {1, 7}, branch7x7dbl_2->getOutputDims(), nChannels_7x7, nChannels_7x7,
{1, 1}, {0, 3}, "branch7x7dbl_3" {1, 7}, {1, 1}, {0, 3}, "branch7x7dbl_3"
); );
branch7x7dbl_4 = new BasicConv2d( branch7x7dbl_4 = new BasicConv2d(
branch7x7dbl_3->getOutputDims(), nChannels_7x7, nChannels_7x7, {7, 1}, branch7x7dbl_3->getOutputDims(), nChannels_7x7, nChannels_7x7,
{1, 1}, {3, 0}, "branch7x7dbl_4" {7, 1}, {1, 1}, {3, 0}, "branch7x7dbl_4"
); );
branch7x7dbl_5 = new BasicConv2d( branch7x7dbl_5 = new BasicConv2d(
branch7x7dbl_4->getOutputDims(), nChannels_7x7, 192, {1, 7}, {1, 1}, branch7x7dbl_4->getOutputDims(), nChannels_7x7, 192, {1, 7}, {1, 1},
@@ -294,7 +294,8 @@ class InceptionC : public CUDANet::Module {
// Branch Pool // Branch Pool
branchPool_1 = new CUDANet::Layers::AvgPooling2d( branchPool_1 = new CUDANet::Layers::AvgPooling2d(
inputSize, inputChannels, {3, 3}, {1, 1}, {1, 1}, CUDANet::Layers::ActivationType::NONE inputSize, inputChannels, {3, 3}, {1, 1}, {1, 1},
CUDANet::Layers::ActivationType::NONE
); );
branchPool_2 = new BasicConv2d( branchPool_2 = new BasicConv2d(
branchPool_1->getOutputDims(), inputChannels, 192, {1, 1}, {1, 1}, branchPool_1->getOutputDims(), inputChannels, 192, {1, 1}, {1, 1},
@@ -303,26 +304,22 @@ class InceptionC : public CUDANet::Module {
// Concat // Concat
concat_1 = new CUDANet::Layers::Concat( concat_1 = new CUDANet::Layers::Concat(
branch1x1->getOutputSize(), branch1x1->getOutputSize(), branch7x7_3->getOutputSize()
branch7x7_3->getOutputSize()
); );
concat_2 = new CUDANet::Layers::Concat( concat_2 = new CUDANet::Layers::Concat(
concat_1->getOutputSize(), concat_1->getOutputSize(), branch7x7dbl_5->getOutputSize()
branch7x7dbl_5->getOutputSize()
); );
concat_3 = new CUDANet::Layers::Concat( concat_3 = new CUDANet::Layers::Concat(
concat_2->getOutputSize(), concat_2->getOutputSize(), branchPool_2->getOutputSize()
branchPool_2->getOutputSize()
); );
} }
float *forward(const float *d_input) { float *forward(const float *d_input) {
float *branch1x1_output = branch1x1->forward(d_input); float *branch1x1_output = branch1x1->forward(d_input);
float *branch7x7_output = branch7x7_1->forward(d_input); float *branch7x7_output = branch7x7_1->forward(d_input);
branch7x7_output = branch7x7_2->forward(branch7x7_output); branch7x7_output = branch7x7_2->forward(branch7x7_output);
branch7x7_output = branch7x7_3->forward(branch7x7_output); branch7x7_output = branch7x7_3->forward(branch7x7_output);
float *branch7x7dbl_output = branch7x7dbl_1->forward(d_input); float *branch7x7dbl_output = branch7x7dbl_1->forward(d_input);
branch7x7dbl_output = branch7x7dbl_2->forward(branch7x7dbl_output); branch7x7dbl_output = branch7x7dbl_2->forward(branch7x7dbl_output);
@@ -331,11 +328,11 @@ class InceptionC : public CUDANet::Module {
branch7x7dbl_output = branch7x7dbl_5->forward(branch7x7dbl_output); branch7x7dbl_output = branch7x7dbl_5->forward(branch7x7dbl_output);
float *branchPool_output = branchPool_1->forward(d_input); float *branchPool_output = branchPool_1->forward(d_input);
branchPool_output = branchPool_2->forward(branchPool_output); branchPool_output = branchPool_2->forward(branchPool_output);
float *d_output = concat_1->forward(branch1x1_output, branch7x7_output); float *d_output = concat_1->forward(branch1x1_output, branch7x7_output);
d_output = concat_2->forward(d_output, branch7x7dbl_output); d_output = concat_2->forward(d_output, branch7x7dbl_output);
d_output = concat_3->forward(d_output, branchPool_output); d_output = concat_3->forward(d_output, branchPool_output);
return d_output; return d_output;
} }
@@ -357,7 +354,7 @@ class InceptionC : public CUDANet::Module {
BasicConv2d *branch7x7dbl_5; BasicConv2d *branch7x7dbl_5;
CUDANet::Layers::AvgPooling2d *branchPool_1; CUDANet::Layers::AvgPooling2d *branchPool_1;
BasicConv2d *branchPool_2; BasicConv2d *branchPool_2;
CUDANet::Layers::Concat *concat_1; CUDANet::Layers::Concat *concat_1;
CUDANet::Layers::Concat *concat_2; CUDANet::Layers::Concat *concat_2;