Fix inception module prefix

This commit is contained in:
2024-05-27 22:12:08 +02:00
parent 387990edbe
commit 6d6f4d4185

View File

@@ -194,24 +194,25 @@ class InceptionB : public CUDANet::Module {
: inputSize(inputSize), inputChannels(inputChannels) { : inputSize(inputSize), inputChannels(inputChannels) {
// Branch 3x3 // Branch 3x3
branch3x3 = new BasicConv2d( branch3x3 = new BasicConv2d(
inputSize, inputChannels, 384, {3, 3}, {2, 2}, {0, 0}, "branch1x1" inputSize, inputChannels, 384, {3, 3}, {2, 2}, {0, 0},
prefix + ".branch1x1"
); );
addLayer("", branch3x3); addLayer("", branch3x3);
// Branch 3x3dbl // Branch 3x3dbl
branch3x3dbl_1 = new BasicConv2d( branch3x3dbl_1 = new BasicConv2d(
inputSize, inputChannels, 64, {1, 1}, {1, 1}, {0, 0}, inputSize, inputChannels, 64, {1, 1}, {1, 1}, {0, 0},
"branch3x3dbl_1" prefix + ".branch3x3dbl_1"
); );
addLayer("", branch3x3dbl_1); addLayer("", branch3x3dbl_1);
branch3x3dbl_2 = new BasicConv2d( branch3x3dbl_2 = new BasicConv2d(
branch3x3dbl_1->getOutputDims(), 96, 96, {3, 3}, {1, 1}, {1, 1}, branch3x3dbl_1->getOutputDims(), 96, 96, {3, 3}, {1, 1}, {1, 1},
"branch3x3dbl_2" prefix + ".branch3x3dbl_2"
); );
addLayer("", branch3x3dbl_2); addLayer("", branch3x3dbl_2);
branch3x3dbl_3 = new BasicConv2d( branch3x3dbl_3 = new BasicConv2d(
branch3x3dbl_2->getOutputDims(), 96, 96, {3, 3}, {2, 2}, {1, 1}, branch3x3dbl_2->getOutputDims(), 96, 96, {3, 3}, {2, 2}, {1, 1},
"branch3x3dbl_3" prefix + ".branch3x3dbl_3"
); );
addLayer("", branch3x3dbl_3); addLayer("", branch3x3dbl_3);
@@ -282,51 +283,52 @@ class InceptionC : public CUDANet::Module {
: inputSize(inputSize), inputChannels(inputChannels) { : inputSize(inputSize), inputChannels(inputChannels) {
// Branch 1x1 // Branch 1x1
branch1x1 = new BasicConv2d( branch1x1 = new BasicConv2d(
inputSize, inputChannels, 192, {1, 1}, {1, 1}, {0, 0}, "branch1x1" inputSize, inputChannels, 192, {1, 1}, {1, 1}, {0, 0},
prefix + ".branch1x1"
); );
addLayer("", branch1x1); addLayer("", branch1x1);
// Branch 7x7 // Branch 7x7
branch7x7_1 = new BasicConv2d( branch7x7_1 = new BasicConv2d(
inputSize, inputChannels, nChannels_7x7, {1, 1}, {1, 1}, {0, 0}, inputSize, inputChannels, nChannels_7x7, {1, 1}, {1, 1}, {0, 0},
"branch7x7_1" prefix + ".branch7x7_1"
); );
addLayer("", branch7x7_1); addLayer("", branch7x7_1);
branch7x7_2 = new BasicConv2d( branch7x7_2 = new BasicConv2d(
branch7x7_1->getOutputDims(), nChannels_7x7, nChannels_7x7, {1, 7}, branch7x7_1->getOutputDims(), nChannels_7x7, nChannels_7x7, {1, 7},
{1, 1}, {0, 3}, "branch7x7_2" {1, 1}, {0, 3}, prefix + ".branch7x7_2"
); );
addLayer("", branch7x7_2); addLayer("", branch7x7_2);
branch7x7_3 = new BasicConv2d( branch7x7_3 = new BasicConv2d(
branch7x7_2->getOutputDims(), nChannels_7x7, 192, {7, 1}, {1, 1}, branch7x7_2->getOutputDims(), nChannels_7x7, 192, {7, 1}, {1, 1},
{3, 0}, "branch7x7_3" {3, 0}, prefix + ".branch7x7_3"
); );
addLayer("", branch7x7_3); addLayer("", branch7x7_3);
// Branch 7x7dbl // Branch 7x7dbl
branch7x7dbl_1 = new BasicConv2d( branch7x7dbl_1 = new BasicConv2d(
inputSize, inputChannels, nChannels_7x7, {1, 1}, {1, 1}, {0, 0}, inputSize, inputChannels, nChannels_7x7, {1, 1}, {1, 1}, {0, 0},
"branch7x7dbl_1" prefix + ".branch7x7dbl_1"
); );
addLayer("", branch7x7dbl_1); addLayer("", branch7x7dbl_1);
branch7x7dbl_2 = new BasicConv2d( branch7x7dbl_2 = new BasicConv2d(
branch7x7dbl_1->getOutputDims(), nChannels_7x7, nChannels_7x7, branch7x7dbl_1->getOutputDims(), nChannels_7x7, nChannels_7x7,
{7, 1}, {1, 1}, {3, 0}, "branch7x7dbl_2" {7, 1}, {1, 1}, {3, 0}, prefix + ".branch7x7dbl_2"
); );
addLayer("", branch7x7dbl_2); addLayer("", branch7x7dbl_2);
branch7x7dbl_3 = new BasicConv2d( branch7x7dbl_3 = new BasicConv2d(
branch7x7dbl_2->getOutputDims(), nChannels_7x7, nChannels_7x7, branch7x7dbl_2->getOutputDims(), nChannels_7x7, nChannels_7x7,
{1, 7}, {1, 1}, {0, 3}, "branch7x7dbl_3" {1, 7}, {1, 1}, {0, 3}, prefix + ".branch7x7dbl_3"
); );
addLayer("", branch7x7dbl_3); addLayer("", branch7x7dbl_3);
branch7x7dbl_4 = new BasicConv2d( branch7x7dbl_4 = new BasicConv2d(
branch7x7dbl_3->getOutputDims(), nChannels_7x7, nChannels_7x7, branch7x7dbl_3->getOutputDims(), nChannels_7x7, nChannels_7x7,
{7, 1}, {1, 1}, {3, 0}, "branch7x7dbl_4" {7, 1}, {1, 1}, {3, 0}, prefix + ".branch7x7dbl_4"
); );
addLayer("", branch7x7dbl_4); addLayer("", branch7x7dbl_4);
branch7x7dbl_5 = new BasicConv2d( branch7x7dbl_5 = new BasicConv2d(
branch7x7dbl_4->getOutputDims(), nChannels_7x7, 192, {1, 7}, {1, 1}, branch7x7dbl_4->getOutputDims(), nChannels_7x7, 192, {1, 7}, {1, 1},
{0, 3}, "branch7x7dbl_5" {0, 3}, prefix + ".branch7x7dbl_5"
); );
addLayer("", branch7x7dbl_5); addLayer("", branch7x7dbl_5);
@@ -338,7 +340,7 @@ class InceptionC : public CUDANet::Module {
addLayer("", branchPool_1); addLayer("", branchPool_1);
branchPool_2 = new BasicConv2d( branchPool_2 = new BasicConv2d(
branchPool_1->getOutputDims(), inputChannels, 192, {1, 1}, {1, 1}, branchPool_1->getOutputDims(), inputChannels, 192, {1, 1}, {1, 1},
{0, 0}, "branchPool_2" {0, 0}, prefix + ".branchPool_2"
); );
addLayer("", branchPool_2); addLayer("", branchPool_2);
@@ -429,33 +431,33 @@ class InceptionD : public CUDANet::Module {
// Branch 3x3 // Branch 3x3
branch3x3_1 = new BasicConv2d( branch3x3_1 = new BasicConv2d(
inputSize, inputChannels, 192, {1, 1}, {1, 1}, {0, 0}, inputSize, inputChannels, 192, {1, 1}, {1, 1}, {0, 0},
prefix + "branch3x3" prefix + ".branch3x3"
); );
addLayer("", branch3x3_1); addLayer("", branch3x3_1);
branch3x3_2 = new BasicConv2d( branch3x3_2 = new BasicConv2d(
inputSize, 192, 320, {3, 3}, {2, 2}, {0, 0}, prefix + "branch3x3_2" inputSize, 192, 320, {3, 3}, {2, 2}, {0, 0}, prefix + ".branch3x3_2"
); );
addLayer("", branch3x3_2); addLayer("", branch3x3_2);
// Branch 7x7x3 // Branch 7x7x3
branch7x7x3_1 = new BasicConv2d( branch7x7x3_1 = new BasicConv2d(
inputSize, inputChannels, 192, {1, 1}, {1, 1}, {0, 0}, inputSize, inputChannels, 192, {1, 1}, {1, 1}, {0, 0},
prefix + "branch7x7x3_1" prefix + ".branch7x7x3_1"
); );
addLayer("", branch7x7x3_1); addLayer("", branch7x7x3_1);
branch7x7x3_2 = new BasicConv2d( branch7x7x3_2 = new BasicConv2d(
inputSize, 192, 192, {1, 7}, {1, 1}, {0, 3}, inputSize, 192, 192, {1, 7}, {1, 1}, {0, 3},
prefix + "branch7x7x3_2" prefix + ".branch7x7x3_2"
); );
addLayer("", branch7x7x3_2); addLayer("", branch7x7x3_2);
branch7x7x3_3 = new BasicConv2d( branch7x7x3_3 = new BasicConv2d(
inputSize, 192, 192, {7, 1}, {1, 1}, {3, 0}, inputSize, 192, 192, {7, 1}, {1, 1}, {3, 0},
prefix + "branch7x7x3_3" prefix + ".branch7x7x3_3"
); );
addLayer("", branch7x7x3_3); addLayer("", branch7x7x3_3);
branch7x7x3_4 = new BasicConv2d( branch7x7x3_4 = new BasicConv2d(
inputSize, 192, 192, {3, 3}, {2, 2}, {0, 0}, inputSize, 192, 192, {3, 3}, {2, 2}, {0, 0},
prefix + "branch7x7x3_4" prefix + ".branch7x7x3_4"
); );
addLayer("", branch7x7x3_4); addLayer("", branch7x7x3_4);
@@ -524,26 +526,33 @@ class InceptionD : public CUDANet::Module {
class InceptionE : public CUDANet::Module { class InceptionE : public CUDANet::Module {
public: public:
InceptionE(shape2d inputSize, int inputChannels) InceptionE(
const shape2d inputSize,
const int inputChannels,
const std::string &prefix
)
: inputSize(inputSize), inputChannels(inputChannels) { : inputSize(inputSize), inputChannels(inputChannels) {
// Branch 1x1 // Branch 1x1
branch1x1 = new BasicConv2d( branch1x1 = new BasicConv2d(
inputSize, inputChannels, 320, {1, 1}, {1, 1}, {0, 0}, "branch1x1" inputSize, inputChannels, 320, {1, 1}, {1, 1}, {0, 0},
prefix + ".branch1x1"
); );
addLayer("", branch1x1); addLayer("", branch1x1);
// Branch 3x3 // Branch 3x3
branch3x3_1 = new BasicConv2d( branch3x3_1 = new BasicConv2d(
inputSize, inputChannels, 384, {1, 1}, {1, 1}, {0, 0}, "branch3x3_1" inputSize, inputChannels, 384, {1, 1}, {1, 1}, {0, 0},
prefix + ".branch3x3_1"
); );
addLayer("", branch3x3_1); addLayer("", branch3x3_1);
branch3x3_2a = new BasicConv2d( branch3x3_2a = new BasicConv2d(
inputSize, 384, 384, {1, 3}, {1, 1}, {0, 1}, "branch3x3_2a" inputSize, 384, 384, {1, 3}, {1, 1}, {0, 1},
prefix + ".branch3x3_2a"
); );
addLayer("", branch3x3_2a); addLayer("", branch3x3_2a);
branch3x3_2b = new BasicConv2d( branch3x3_2b = new BasicConv2d(
inputSize, 384, 384, {3, 1}, {1, 1}, {1, 0}, "branch3x3_2b" inputSize, 384, 384, {3, 1}, {1, 1}, {1, 0},
prefix + ".branch3x3_2b"
); );
addLayer("", branch3x3_2b); addLayer("", branch3x3_2b);
branch_3x3_2_concat = new CUDANet::Layers::Concat( branch_3x3_2_concat = new CUDANet::Layers::Concat(
@@ -553,19 +562,22 @@ class InceptionE : public CUDANet::Module {
// Branch 3x3dbl // Branch 3x3dbl
branch3x3dbl_1 = new BasicConv2d( branch3x3dbl_1 = new BasicConv2d(
inputSize, inputChannels, 448, {1, 1}, {1, 1}, {0, 0}, inputSize, inputChannels, 448, {1, 1}, {1, 1}, {0, 0},
"branch3x3dbl_1" prefix + ".branch3x3dbl_1"
); );
addLayer("", branch3x3dbl_1); addLayer("", branch3x3dbl_1);
branch3x3dbl_2 = new BasicConv2d( branch3x3dbl_2 = new BasicConv2d(
inputSize, 448, 384, {3, 3}, {1, 1}, {1, 1}, "branch3x3dbl_2" inputSize, 448, 384, {3, 3}, {1, 1}, {1, 1},
prefix + ".branch3x3dbl_2"
); );
addLayer("", branch3x3dbl_2); addLayer("", branch3x3dbl_2);
branch3x3dbl_3a = new BasicConv2d( branch3x3dbl_3a = new BasicConv2d(
inputSize, 384, 384, {1, 3}, {1, 1}, {0, 1}, "branch3x3dbl_3a" inputSize, 384, 384, {1, 3}, {1, 1}, {0, 1},
prefix + ".branch3x3dbl_3a"
); );
addLayer("", branch3x3dbl_3a); addLayer("", branch3x3dbl_3a);
branch3x3dbl_3b = new BasicConv2d( branch3x3dbl_3b = new BasicConv2d(
inputSize, 384, 384, {3, 1}, {1, 1}, {1, 0}, "branch3x3dbl_3b" inputSize, 384, 384, {3, 1}, {1, 1}, {1, 0},
prefix + ".branch3x3dbl_3b"
); );
addLayer("", branch3x3dbl_3b); addLayer("", branch3x3dbl_3b);
branch_3x3dbl_3_concat = new CUDANet::Layers::Concat( branch_3x3dbl_3_concat = new CUDANet::Layers::Concat(
@@ -574,11 +586,13 @@ class InceptionE : public CUDANet::Module {
// Branch Pool // Branch Pool
branchPool_1 = new CUDANet::Layers::AvgPooling2d( branchPool_1 = new CUDANet::Layers::AvgPooling2d(
inputSize, inputChannels, {3, 3}, {1, 1}, {1, 1}, CUDANet::Layers::ActivationType::NONE inputSize, inputChannels, {3, 3}, {1, 1}, {1, 1},
CUDANet::Layers::ActivationType::NONE
); );
addLayer("", branchPool_1); addLayer("", branchPool_1);
branchPool_2 = new BasicConv2d( branchPool_2 = new BasicConv2d(
inputSize, inputChannels, 192, {1, 1}, {1, 1}, {0, 0}, "branchPool_2" inputSize, inputChannels, 192, {1, 1}, {1, 1}, {0, 0},
prefix + ".branchPool_2"
); );
addLayer("", branchPool_2); addLayer("", branchPool_2);
@@ -615,23 +629,25 @@ class InceptionE : public CUDANet::Module {
float *forward(const float *d_input) { float *forward(const float *d_input) {
float *branch1x1_output = branch1x1->forward(d_input); float *branch1x1_output = branch1x1->forward(d_input);
float *branch3x3_output = branch3x3_1->forward(d_input); float *branch3x3_output = branch3x3_1->forward(d_input);
float *branch3x3_2a_output = branch3x3_2a->forward(branch3x3_output); float *branch3x3_2a_output = branch3x3_2a->forward(branch3x3_output);
float *branch3x3_2b_output = branch3x3_2b->forward(branch3x3_output); float *branch3x3_2b_output = branch3x3_2b->forward(branch3x3_output);
branch3x3_output = branch_3x3_2_concat->forward( branch3x3_output = branch_3x3_2_concat->forward(
branch3x3_2a_output, branch3x3_2b_output branch3x3_2a_output, branch3x3_2b_output
); );
float *branch3x3dbl_output = branch3x3dbl_1->forward(d_input); float *branch3x3dbl_output = branch3x3dbl_1->forward(d_input);
branch3x3dbl_output = branch3x3dbl_2->forward(branch3x3dbl_output); branch3x3dbl_output = branch3x3dbl_2->forward(branch3x3dbl_output);
float *branch3x3dbl_3a_output = branch3x3dbl_3a->forward(branch3x3dbl_output); float *branch3x3dbl_3a_output =
float *branch3x3dbl_3b_output = branch3x3dbl_3b->forward(branch3x3dbl_output); branch3x3dbl_3a->forward(branch3x3dbl_output);
float *branch3x3dbl_3b_output =
branch3x3dbl_3b->forward(branch3x3dbl_output);
branch3x3dbl_output = branch_3x3dbl_3_concat->forward( branch3x3dbl_output = branch_3x3dbl_3_concat->forward(
branch3x3dbl_3a_output, branch3x3dbl_3b_output branch3x3dbl_3a_output, branch3x3dbl_3b_output
); );
float *branchPool_output = branchPool_1->forward(d_input); float *branchPool_output = branchPool_1->forward(d_input);
branchPool_output = branchPool_2->forward(branchPool_output); branchPool_output = branchPool_2->forward(branchPool_output);
float *d_output = concat_1->forward(branch1x1_output, branch3x3_output); float *d_output = concat_1->forward(branch1x1_output, branch3x3_output);
d_output = concat_2->forward(d_output, branch3x3dbl_output); d_output = concat_2->forward(d_output, branch3x3dbl_output);