diff --git a/include/layers/avg_pooling.cuh b/include/layers/avg_pooling.cuh index 60070ae..b9f41ab 100644 --- a/include/layers/avg_pooling.cuh +++ b/include/layers/avg_pooling.cuh @@ -6,7 +6,7 @@ namespace CUDANet::Layers { -class AvgPooling2d : public SequentialLayer { +class AvgPooling2d : public SequentialLayer, public TwoDLayer { public: AvgPooling2d( dim2d inputSize, @@ -33,6 +33,8 @@ class AvgPooling2d : public SequentialLayer { */ int getInputSize(); + dim2d getOutputDims(); + private: dim2d inputSize; int nChannels; diff --git a/include/layers/batch_norm.cuh b/include/layers/batch_norm.cuh index eb1cd19..7da46f0 100644 --- a/include/layers/batch_norm.cuh +++ b/include/layers/batch_norm.cuh @@ -8,7 +8,7 @@ namespace CUDANet::Layers { -class BatchNorm2d : public WeightedLayer { +class BatchNorm2d : public WeightedLayer, public TwoDLayer { public: BatchNorm2d(dim2d inputSize, int inputChannels, float epsilon, ActivationType activationType); @@ -64,6 +64,8 @@ class BatchNorm2d : public WeightedLayer { */ int getInputSize(); + dim2d getOutputDims(); + private: dim2d inputSize; diff --git a/include/layers/conv2d.cuh b/include/layers/conv2d.cuh index 19c9003..9cb9978 100644 --- a/include/layers/conv2d.cuh +++ b/include/layers/conv2d.cuh @@ -13,7 +13,7 @@ namespace CUDANet::Layers { * @brief 2D convolutional layer * */ -class Conv2d : public WeightedLayer { +class Conv2d : public WeightedLayer, public TwoDLayer { public: /** * @brief Construct a new Conv 2d layer @@ -102,6 +102,8 @@ class Conv2d : public WeightedLayer { return paddingSize; } + dim2d getOutputDims(); + private: // Inputs dim2d inputSize; diff --git a/include/layers/layer.cuh b/include/layers/layer.cuh index 22f14e7..52b27af 100644 --- a/include/layers/layer.cuh +++ b/include/layers/layer.cuh @@ -11,6 +11,15 @@ typedef std::pair dim2d; namespace CUDANet::Layers { + +class TwoDLayer { + + public: + virtual dim2d getOutputDims() = 0; + +}; + + /** * @brief Basic Sequential Layer * diff --git a/include/layers/max_pooling.cuh b/include/layers/max_pooling.cuh index 7aa35e7..020e643 100644 --- a/include/layers/max_pooling.cuh +++ b/include/layers/max_pooling.cuh @@ -6,7 +6,7 @@ namespace CUDANet::Layers { -class MaxPooling2d : public SequentialLayer { +class MaxPooling2d : public SequentialLayer, public TwoDLayer { public: MaxPooling2d( dim2d inputSize, @@ -33,6 +33,8 @@ class MaxPooling2d : public SequentialLayer { */ int getInputSize(); + dim2d getOutputDims(); + private: dim2d inputSize; int nChannels; diff --git a/src/layers/avg_pooling.cu b/src/layers/avg_pooling.cu index 898fd66..be35ebd 100644 --- a/src/layers/avg_pooling.cu +++ b/src/layers/avg_pooling.cu @@ -61,4 +61,8 @@ int AvgPooling2d::getOutputSize() { int AvgPooling2d::getInputSize() { return inputSize.first * inputSize.second * nChannels; +} + +dim2d AvgPooling2d::getOutputDims() { + return outputSize; } \ No newline at end of file diff --git a/src/layers/batch_norm.cu b/src/layers/batch_norm.cu index 774778c..c311d6f 100644 --- a/src/layers/batch_norm.cu +++ b/src/layers/batch_norm.cu @@ -128,6 +128,10 @@ int BatchNorm2d::getOutputSize() { return inputSize.first * inputSize.second * inputChannels; } +dim2d BatchNorm2d::getOutputDims() { + return inputSize; +} + float *BatchNorm2d::forward(const float *d_input) { // Compute per-channel batch normalization for (int i = 0; i < inputChannels; i++) { diff --git a/src/layers/conv2d.cu b/src/layers/conv2d.cu index b7e83a3..e8069be 100644 --- a/src/layers/conv2d.cu +++ b/src/layers/conv2d.cu @@ -137,4 +137,8 @@ int Conv2d::getOutputSize() { int Conv2d::getInputSize() { return inputSize.first * inputSize.second * inputChannels; +} + +dim2d Conv2d::getOutputDims() { + return outputSize; } \ No newline at end of file diff --git a/src/layers/max_pooling.cu b/src/layers/max_pooling.cu index fd48227..f1819cd 100644 --- a/src/layers/max_pooling.cu +++ b/src/layers/max_pooling.cu @@ -59,4 +59,8 @@ int MaxPooling2d::getOutputSize() { int MaxPooling2d::getInputSize() { return inputSize.first * inputSize.second * nChannels; +} + +dim2d MaxPooling2d::getOutputDims() { + return outputSize; } \ No newline at end of file