Add inception block A

This commit is contained in:
2024-05-19 18:19:07 +02:00
parent d0b974dd9f
commit e8bffe22d5

View File

@@ -20,29 +20,85 @@ class BasicConv2d : public CUDANet::Module {
const int kernelSize, const int kernelSize,
const int stride, const int stride,
const int padding, const int padding,
const std::string& prefix const std::string &prefix
) )
: inputSize(inputSize), : inputSize(inputSize),
inputChannels(inputChannels), inputChannels(inputChannels),
outputChannels(outputChannels) { outputChannels(outputChannels) {
// Create the convolution layer // Create the convolution layer
CUDANet::Layers::Conv2d *conv = new CUDANet::Layers::Conv2d( CUDANet::Layers::Conv2d *conv = new CUDANet::Layers::Conv2d(
inputSize, inputChannels, kernelSize, stride, outputChannels, padding, CUDANet::Layers::ActivationType::NONE inputSize, inputChannels, kernelSize, stride, outputChannels,
padding, CUDANet::Layers::ActivationType::NONE
); );
int batchNormSize = conv->getOutputSize(); int batchNormSize = conv->getOutputSize();
CUDANet::Layers::BatchNorm *batchNorm = CUDANet::Layers::BatchNorm *batchNorm = new CUDANet::Layers::BatchNorm(
new CUDANet::Layers::BatchNorm( batchNormSize, outputChannels, 1e-3f,
batchNormSize, outputChannels, 1e-3f, CUDANet::Layers::ActivationType::RELU CUDANet::Layers::ActivationType::RELU
); );
addLayer(prefix + ".conv", conv); addLayer(prefix + ".conv", conv);
addLayer(prefix + ".bn", batchNorm); addLayer(prefix + ".bn", batchNorm);
} }
float* forward(const float* d_input) {
for (auto& layer : layers) {
d_input = layer.second->forward(d_input);
}
return d_input;
}
private: private:
int inputSize; int inputSize;
int inputChannels; int inputChannels;
int outputChannels; int outputChannels;
}; };
class InceptionA : public CUDANet::Module {
public:
InceptionA(
const int inputSize,
const int inputChannels,
const int poolFeatures,
const std::string &prefix
)
: inputSize(inputSize),
inputChannels(inputChannels),
poolFeatures(poolFeatures) {
// Branch 1x1
CUDANet::Module *branch1x1 = new BasicConv2d(
inputSize, inputChannels, 64, 1, 1, 0, prefix + ".branch1x1"
);
// Branch 5x5
CUDANet::Module *branch5x5_1 = new BasicConv2d(
inputSize, inputChannels, 48, 1, 1, 0, prefix + ".branch5x5_1"
);
CUDANet::Module *branch5x5_2 = new BasicConv2d(
inputSize, 48, 64, 5, 1, 2, prefix + ".branch5x5_2"
);
// Branch 3x3
CUDANet::Module *branch3x3_1 = new BasicConv2d(
inputSize, inputChannels, 64, 1, 1, 0, prefix + ".branch3x3_1"
);
CUDANet::Module *branch3x3_2 = new BasicConv2d(
inputSize, 64, 96, 3, 1, 1, prefix + ".branch3x3_2"
);
CUDANet::Module *branch3x3_3 = new BasicConv2d(
inputSize, 96, 96, 3, 1, 1, prefix + ".branch3x3_3"
);
// Branch Pool
CUDANet::Module *branchPool = new BasicConv2d(
inputSize, inputChannels, poolFeatures, 1, 1, 0, prefix + ".branchPool"
);
}
private:
int inputSize;
int inputChannels;
int poolFeatures;
};