mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-05 17:34:21 +00:00
Start implementing inception block A
This commit is contained in:
@@ -21,10 +21,7 @@ class BasicConv2d : public CUDANet::Module {
|
|||||||
const int stride,
|
const int stride,
|
||||||
const int padding,
|
const int padding,
|
||||||
const std::string &prefix
|
const std::string &prefix
|
||||||
)
|
) {
|
||||||
: inputSize(inputSize),
|
|
||||||
inputChannels(inputChannels),
|
|
||||||
outputChannels(outputChannels) {
|
|
||||||
// Create the convolution layer
|
// Create the convolution layer
|
||||||
CUDANet::Layers::Conv2d *conv = new CUDANet::Layers::Conv2d(
|
CUDANet::Layers::Conv2d *conv = new CUDANet::Layers::Conv2d(
|
||||||
inputSize, inputChannels, kernelSize, stride, outputChannels,
|
inputSize, inputChannels, kernelSize, stride, outputChannels,
|
||||||
@@ -43,16 +40,13 @@ class BasicConv2d : public CUDANet::Module {
|
|||||||
}
|
}
|
||||||
|
|
||||||
float* forward(const float* d_input) {
|
float* forward(const float* d_input) {
|
||||||
|
|
||||||
for (auto& layer : layers) {
|
for (auto& layer : layers) {
|
||||||
d_input = layer.second->forward(d_input);
|
d_input = layer.second->forward(d_input);
|
||||||
}
|
}
|
||||||
return d_input;
|
return d_input;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
|
||||||
int inputSize;
|
|
||||||
int inputChannels;
|
|
||||||
int outputChannels;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class InceptionA : public CUDANet::Module {
|
class InceptionA : public CUDANet::Module {
|
||||||
@@ -71,34 +65,53 @@ class InceptionA : public CUDANet::Module {
|
|||||||
CUDANet::Module *branch1x1 = new BasicConv2d(
|
CUDANet::Module *branch1x1 = new BasicConv2d(
|
||||||
inputSize, inputChannels, 64, 1, 1, 0, prefix + ".branch1x1"
|
inputSize, inputChannels, 64, 1, 1, 0, prefix + ".branch1x1"
|
||||||
);
|
);
|
||||||
|
addLayer("", branch1x1);
|
||||||
|
|
||||||
// Branch 5x5
|
// Branch 5x5
|
||||||
CUDANet::Module *branch5x5_1 = new BasicConv2d(
|
CUDANet::Module *branch5x5_1 = new BasicConv2d(
|
||||||
inputSize, inputChannels, 48, 1, 1, 0, prefix + ".branch5x5_1"
|
inputSize, inputChannels, 48, 1, 1, 0, prefix + ".branch5x5_1"
|
||||||
);
|
);
|
||||||
|
addLayer("", branch5x5_1);
|
||||||
CUDANet::Module *branch5x5_2 = new BasicConv2d(
|
CUDANet::Module *branch5x5_2 = new BasicConv2d(
|
||||||
inputSize, 48, 64, 5, 1, 2, prefix + ".branch5x5_2"
|
inputSize, 48, 64, 5, 1, 2, prefix + ".branch5x5_2"
|
||||||
);
|
);
|
||||||
|
addLayer("", branch5x5_2);
|
||||||
|
|
||||||
// Branch 3x3
|
// Branch 3x3
|
||||||
CUDANet::Module *branch3x3_1 = new BasicConv2d(
|
CUDANet::Module *branch3x3_1 = new BasicConv2d(
|
||||||
inputSize, inputChannels, 64, 1, 1, 0, prefix + ".branch3x3_1"
|
inputSize, inputChannels, 64, 1, 1, 0, prefix + ".branch3x3_1"
|
||||||
);
|
);
|
||||||
|
addLayer("", branch3x3_1);
|
||||||
CUDANet::Module *branch3x3_2 = new BasicConv2d(
|
CUDANet::Module *branch3x3_2 = new BasicConv2d(
|
||||||
inputSize, 64, 96, 3, 1, 1, prefix + ".branch3x3_2"
|
inputSize, 64, 96, 3, 1, 1, prefix + ".branch3x3_2"
|
||||||
);
|
);
|
||||||
|
addLayer("", branch3x3_2);
|
||||||
CUDANet::Module *branch3x3_3 = new BasicConv2d(
|
CUDANet::Module *branch3x3_3 = new BasicConv2d(
|
||||||
inputSize, 96, 96, 3, 1, 1, prefix + ".branch3x3_3"
|
inputSize, 96, 96, 3, 1, 1, prefix + ".branch3x3_3"
|
||||||
);
|
);
|
||||||
|
addLayer("", branch3x3_3);
|
||||||
|
|
||||||
// Branch Pool
|
// Branch Pool
|
||||||
CUDANet::Module *branchPool = new BasicConv2d(
|
CUDANet::Module *branchPool = new BasicConv2d(
|
||||||
inputSize, inputChannels, poolFeatures, 1, 1, 0, prefix + ".branchPool"
|
inputSize, inputChannels, poolFeatures, 1, 1, 0, prefix + ".branchPool"
|
||||||
);
|
);
|
||||||
|
addLayer("", branchPool);
|
||||||
|
|
||||||
|
// Concat
|
||||||
|
concat_1 = new CUDANet::Layers::Concat(
|
||||||
|
branch1x1->getOutputSize(), branch5x5_2->getOutputSize()
|
||||||
|
);
|
||||||
|
concat_2 = new CUDANet::Layers::Concat(
|
||||||
|
concat_1->getOutputSize(), branch3x3_3->getOutputSize()
|
||||||
|
);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int inputSize;
|
int inputSize;
|
||||||
int inputChannels;
|
int inputChannels;
|
||||||
int poolFeatures;
|
int poolFeatures;
|
||||||
|
|
||||||
|
CUDANet::Layers::Concat *concat_1;
|
||||||
|
CUDANet::Layers::Concat *concat_2;
|
||||||
};
|
};
|
||||||
Reference in New Issue
Block a user