Implement Inception block C

This commit is contained in:
2024-05-26 19:25:00 +02:00
parent 4a60e0142c
commit ca9c4619e9

View File

@@ -101,7 +101,7 @@ class InceptionA : public CUDANet::Module {
// Branch Pool
branchPool_1 = new CUDANet::Layers::AvgPooling2d(
inputSize, inputChannels, {3, 3}, {1, 1},
inputSize, inputChannels, {3, 3}, {1, 1}, {1, 1},
CUDANet::Layers::ActivationType::NONE
);
addLayer("", branchPool_1);
@@ -197,7 +197,7 @@ class InceptionB : public CUDANet::Module {
addLayer("", branch3x3dbl_3);
branchPool = new CUDANet::Layers::MaxPooling2d(
inputSize, inputChannels, {3, 3}, {2, 2},
inputSize, inputChannels, {3, 3}, {2, 2}, {0, 0},
CUDANet::Layers::ActivationType::NONE
);
addLayer(prefix + ".branchPool", branchPool);
@@ -241,3 +241,125 @@ class InceptionB : public CUDANet::Module {
CUDANet::Layers::Concat *concat_1;
CUDANet::Layers::Concat *concat_2;
};
class InceptionC : public CUDANet::Module {
public:
InceptionC(
const dim2d inputSize,
const int inputChannels,
const int nChannels_7x7,
const std::string &prefix
)
: inputSize(inputSize), inputChannels(inputChannels) {
// Branch 1x1
branch1x1 = new BasicConv2d(
inputSize, inputChannels, 192, {1, 1}, {1, 1}, {0, 0}, "branch1x1"
);
// Branch 7x7
branch7x7_1 = new BasicConv2d(
inputSize, inputChannels, nChannels_7x7, {1, 1}, {1, 1}, {0, 0},
"branch7x7_1"
);
branch7x7_2 = new BasicConv2d(
branch7x7_1->getOutputDims(), nChannels_7x7, nChannels_7x7, {1, 7},
{1, 1}, {0, 3}, "branch7x7_2"
);
branch7x7_3 = new BasicConv2d(
branch7x7_2->getOutputDims(), nChannels_7x7, 192, {7, 1}, {1, 1},
{3, 0}, "branch7x7_3"
);
// Branch 7x7dbl
branch7x7dbl_1 = new BasicConv2d(
inputSize, inputChannels, nChannels_7x7, {1, 1}, {1, 1}, {0, 0},
"branch7x7dbl_1"
);
branch7x7dbl_2 = new BasicConv2d(
branch7x7dbl_1->getOutputDims(), nChannels_7x7, nChannels_7x7, {7, 1},
{1, 1}, {3, 0}, "branch7x7dbl_2"
);
branch7x7dbl_3 = new BasicConv2d(
branch7x7dbl_2->getOutputDims(), nChannels_7x7, nChannels_7x7, {1, 7},
{1, 1}, {0, 3}, "branch7x7dbl_3"
);
branch7x7dbl_4 = new BasicConv2d(
branch7x7dbl_3->getOutputDims(), nChannels_7x7, nChannels_7x7, {7, 1},
{1, 1}, {3, 0}, "branch7x7dbl_4"
);
branch7x7dbl_5 = new BasicConv2d(
branch7x7dbl_4->getOutputDims(), nChannels_7x7, 192, {1, 7}, {1, 1},
{0, 3}, "branch7x7dbl_5"
);
// Branch Pool
branchPool_1 = new CUDANet::Layers::AvgPooling2d(
inputSize, inputChannels, {3, 3}, {1, 1}, {1, 1}, CUDANet::Layers::ActivationType::NONE
);
branchPool_2 = new BasicConv2d(
branchPool_1->getOutputDims(), inputChannels, 192, {1, 1}, {1, 1},
{0, 0}, "branchPool_2"
);
// Concat
concat_1 = new CUDANet::Layers::Concat(
branch1x1->getOutputSize(),
branch7x7_3->getOutputSize()
);
concat_2 = new CUDANet::Layers::Concat(
concat_1->getOutputSize(),
branch7x7dbl_5->getOutputSize()
);
concat_3 = new CUDANet::Layers::Concat(
concat_2->getOutputSize(),
branchPool_2->getOutputSize()
);
}
float *forward(const float *d_input) {
float *branch1x1_output = branch1x1->forward(d_input);
float *branch7x7_output = branch7x7_1->forward(d_input);
branch7x7_output = branch7x7_2->forward(branch7x7_output);
branch7x7_output = branch7x7_3->forward(branch7x7_output);
float *branch7x7dbl_output = branch7x7dbl_1->forward(d_input);
branch7x7dbl_output = branch7x7dbl_2->forward(branch7x7dbl_output);
branch7x7dbl_output = branch7x7dbl_3->forward(branch7x7dbl_output);
branch7x7dbl_output = branch7x7dbl_4->forward(branch7x7dbl_output);
branch7x7dbl_output = branch7x7dbl_5->forward(branch7x7dbl_output);
float *branchPool_output = branchPool_1->forward(d_input);
branchPool_output = branchPool_2->forward(branchPool_output);
float *d_output = concat_1->forward(branch1x1_output, branch7x7_output);
d_output = concat_2->forward(d_output, branch7x7dbl_output);
d_output = concat_3->forward(d_output, branchPool_output);
return d_output;
}
private:
dim2d inputSize;
int inputChannels;
BasicConv2d *branch1x1;
BasicConv2d *branch7x7_1;
BasicConv2d *branch7x7_2;
BasicConv2d *branch7x7_3;
BasicConv2d *branch7x7dbl_1;
BasicConv2d *branch7x7dbl_2;
BasicConv2d *branch7x7dbl_3;
BasicConv2d *branch7x7dbl_4;
BasicConv2d *branch7x7dbl_5;
CUDANet::Layers::AvgPooling2d *branchPool_1;
BasicConv2d *branchPool_2;
CUDANet::Layers::Concat *concat_1;
CUDANet::Layers::Concat *concat_2;
CUDANet::Layers::Concat *concat_3;
};