Start implementing inception block A

This commit is contained in:
2024-05-19 20:22:26 +02:00
parent b5fb205df8
commit ba4bee90ad

View File

@@ -21,10 +21,7 @@ class BasicConv2d : public CUDANet::Module {
const int stride, const int stride,
const int padding, const int padding,
const std::string &prefix const std::string &prefix
) ) {
: inputSize(inputSize),
inputChannels(inputChannels),
outputChannels(outputChannels) {
// Create the convolution layer // Create the convolution layer
CUDANet::Layers::Conv2d *conv = new CUDANet::Layers::Conv2d( CUDANet::Layers::Conv2d *conv = new CUDANet::Layers::Conv2d(
inputSize, inputChannels, kernelSize, stride, outputChannels, inputSize, inputChannels, kernelSize, stride, outputChannels,
@@ -43,16 +40,13 @@ class BasicConv2d : public CUDANet::Module {
} }
float* forward(const float* d_input) { float* forward(const float* d_input) {
for (auto& layer : layers) { for (auto& layer : layers) {
d_input = layer.second->forward(d_input); d_input = layer.second->forward(d_input);
} }
return d_input; return d_input;
} }
private:
int inputSize;
int inputChannels;
int outputChannels;
}; };
class InceptionA : public CUDANet::Module { class InceptionA : public CUDANet::Module {
@@ -71,34 +65,53 @@ class InceptionA : public CUDANet::Module {
CUDANet::Module *branch1x1 = new BasicConv2d( CUDANet::Module *branch1x1 = new BasicConv2d(
inputSize, inputChannels, 64, 1, 1, 0, prefix + ".branch1x1" inputSize, inputChannels, 64, 1, 1, 0, prefix + ".branch1x1"
); );
addLayer("", branch1x1);
// Branch 5x5 // Branch 5x5
CUDANet::Module *branch5x5_1 = new BasicConv2d( CUDANet::Module *branch5x5_1 = new BasicConv2d(
inputSize, inputChannels, 48, 1, 1, 0, prefix + ".branch5x5_1" inputSize, inputChannels, 48, 1, 1, 0, prefix + ".branch5x5_1"
); );
addLayer("", branch5x5_1);
CUDANet::Module *branch5x5_2 = new BasicConv2d( CUDANet::Module *branch5x5_2 = new BasicConv2d(
inputSize, 48, 64, 5, 1, 2, prefix + ".branch5x5_2" inputSize, 48, 64, 5, 1, 2, prefix + ".branch5x5_2"
); );
addLayer("", branch5x5_2);
// Branch 3x3 // Branch 3x3
CUDANet::Module *branch3x3_1 = new BasicConv2d( CUDANet::Module *branch3x3_1 = new BasicConv2d(
inputSize, inputChannels, 64, 1, 1, 0, prefix + ".branch3x3_1" inputSize, inputChannels, 64, 1, 1, 0, prefix + ".branch3x3_1"
); );
addLayer("", branch3x3_1);
CUDANet::Module *branch3x3_2 = new BasicConv2d( CUDANet::Module *branch3x3_2 = new BasicConv2d(
inputSize, 64, 96, 3, 1, 1, prefix + ".branch3x3_2" inputSize, 64, 96, 3, 1, 1, prefix + ".branch3x3_2"
); );
addLayer("", branch3x3_2);
CUDANet::Module *branch3x3_3 = new BasicConv2d( CUDANet::Module *branch3x3_3 = new BasicConv2d(
inputSize, 96, 96, 3, 1, 1, prefix + ".branch3x3_3" inputSize, 96, 96, 3, 1, 1, prefix + ".branch3x3_3"
); );
addLayer("", branch3x3_3);
// Branch Pool // Branch Pool
CUDANet::Module *branchPool = new BasicConv2d( CUDANet::Module *branchPool = new BasicConv2d(
inputSize, inputChannels, poolFeatures, 1, 1, 0, prefix + ".branchPool" inputSize, inputChannels, poolFeatures, 1, 1, 0, prefix + ".branchPool"
); );
addLayer("", branchPool);
// Concat
concat_1 = new CUDANet::Layers::Concat(
branch1x1->getOutputSize(), branch5x5_2->getOutputSize()
);
concat_2 = new CUDANet::Layers::Concat(
concat_1->getOutputSize(), branch3x3_3->getOutputSize()
);
} }
private: private:
int inputSize; int inputSize;
int inputChannels; int inputChannels;
int poolFeatures; int poolFeatures;
CUDANet::Layers::Concat *concat_1;
CUDANet::Layers::Concat *concat_2;
}; };