From e4d05931d4edd60798c657dad497938dee1b665c Mon Sep 17 00:00:00 2001
From: LordMathis
Date: Wed, 19 Nov 2025 21:44:19 +0100
Subject: [PATCH] Migrate MaxPool2d layer to Tensors
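
Replace the raw-pointer MaxPooling2d layer (shape2d sizes, fused
activation, separate CPU/CUDA forward paths) with a Tensor-based
MaxPool2d that delegates to a new Backend::maxPool2d operation. The
Input and Output copy layers are dropped as part of the migration.

Output dimensions follow the usual pooling formula, per spatial
dimension:

    out = (in + 2 * padding - pool) / stride + 1

so a 32x32 input with a 2x2 window, stride 2 and no padding yields a
16x16 output per channel.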
---
 include/backend.hpp                     | 10 ++++
 include/backend/cuda.cuh                | 10 ++++
 include/kernels/pooling.cuh             | 31 +++++------
 include/layers/input.hpp                | 66 -----------------------
 include/layers/max_pool.hpp             | 51 ++++++++++++++++++
 include/layers/max_pooling.hpp          | 63 ----------------------
 include/layers/output.hpp               | 59 ---------------------
 src/backends/cuda/kernels/pooling.cu    | 68 ++++++++++++------------
 src/backends/cuda/layer_ops.cu          | 27 ++++++++++
 src/backends/cuda/layers/input.cu       | 22 --------
 src/backends/cuda/layers/max_pooling.cu | 38 --------------
 src/backends/cuda/layers/output.cu      | 14 -----
 src/layers/input.cpp                    | 37 -------------
 src/layers/max_pool.cpp                 | 70 +++++++++++++++++++++++++
 src/layers/max_pooling.cpp              | 67 -----------------------
 src/layers/output.cu                    | 34 ------------
 16 files changed, 214 insertions(+), 453 deletions(-)
 delete mode 100644 include/layers/input.hpp
 create mode 100644 include/layers/max_pool.hpp
 delete mode 100644 include/layers/max_pooling.hpp
 delete mode 100644 include/layers/output.hpp
 delete mode 100644 src/backends/cuda/layers/input.cu
 delete mode 100644 src/backends/cuda/layers/max_pooling.cu
 delete mode 100644 src/backends/cuda/layers/output.cu
 delete mode 100644 src/layers/input.cpp
 create mode 100644 src/layers/max_pool.cpp
 delete mode 100644 src/layers/max_pooling.cpp
 delete mode 100644 src/layers/output.cu
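
Usage sketch for the new layer (illustrative only; the CUDA backend
construction and the Tensor upload path shown here are assumptions,
not part of this patch):

    // Hypothetical setup: a CUDA backend and a flat HWC float input.
    CUDANet::Backend::CUDA backend;
    CUDANet::Tensor input(
        CUDANet::Shape{32 * 32 * 3}, CUDANet::DType::FLOAT32, &backend
    );

    // 2x2 max pool, stride 2, no padding, over a 32x32x3 input.
    CUDANet::Layers::MaxPool2d pool(
        CUDANet::Shape{32, 32, 3},  // height, width, channels
        CUDANet::Shape{2, 2},       // pooling window
        CUDANet::Shape{2, 2},       // stride
        CUDANet::Shape{0, 0},       // padding
        &backend
    );

    CUDANet::Tensor& out = pool.forward(input);  // 16 * 16 * 3 values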
diff --git a/include/backend.hpp b/include/backend.hpp
index d877ad8..1edb47e 100644
--- a/include/backend.hpp
+++ b/include/backend.hpp
@@ -52,6 +52,16 @@ class Backend {
         const CUDANet::Shape stride_shape,
         const CUDANet::Shape out_shape
     ) = 0;
+
+    virtual CUDANet::Tensor& maxPool2d(
+        const CUDANet::Tensor& input,
+        CUDANet::Tensor& output,
+        CUDANet::Shape input_shape,
+        CUDANet::Shape pool_shape,
+        CUDANet::Shape stride_shape,
+        CUDANet::Shape padding_shape,
+        CUDANet::Shape output_shape
+    ) = 0;
 };
 
 }  // namespace CUDANet
\ No newline at end of file
diff --git a/include/backend/cuda.cuh b/include/backend/cuda.cuh
index 0ee5247..1dbd649 100644
--- a/include/backend/cuda.cuh
+++ b/include/backend/cuda.cuh
@@ -48,6 +48,16 @@ class CUDA : public Backend {
         const CUDANet::Shape stride_shape,
         const CUDANet::Shape out_shape
     ) override;
+
+    CUDANet::Tensor& maxPool2d(
+        const CUDANet::Tensor& input,
+        CUDANet::Tensor& output,
+        CUDANet::Shape input_shape,
+        CUDANet::Shape pool_shape,
+        CUDANet::Shape stride_shape,
+        CUDANet::Shape padding_shape,
+        CUDANet::Shape output_shape
+    ) override;
 };
 
 }  // namespace CUDANet::Backend
\ No newline at end of file
diff --git a/include/kernels/pooling.cuh b/include/kernels/pooling.cuh
index 3ca8578..ed92256 100644
--- a/include/kernels/pooling.cuh
+++ b/include/kernels/pooling.cuh
@@ -1,33 +1,28 @@
-#ifndef CUDANET_POOLING_H
-#define CUDANET_POOLING_H
+#pragma once
 
 #include <cuda_runtime.h>
 
 #include "layer.hpp"
 
 namespace CUDANet::Kernels {
 
-__global__ void max_pooling(
+__global__ void max_pool(
     const float* __restrict__ d_input,
     float* __restrict__ d_output,
-    const shape2d inputSize,
-    const shape2d outputSize,
-    const int nChannels,
-    const shape2d poolingSize,
-    const shape2d stride,
-    const shape2d padding
+    const Shape input_shape,
+    const Shape output_shape,
+    const Shape pool_shape,
+    const Shape stride_shape,
+    const Shape padding_shape
 );
 
-__global__ void avg_pooling(
+__global__ void avg_pool(
     const float* __restrict__ d_input,
     float* __restrict__ d_output,
-    const shape2d inputSize,
-    const shape2d outputSize,
-    const int nChannels,
-    const shape2d poolingSize,
-    const shape2d stride,
-    const shape2d padding
+    const Shape input_shape,
+    const Shape output_shape,
+    const Shape pool_shape,
+    const Shape stride_shape,
+    const Shape padding_shape
 );
 
 }  // namespace CUDANet::Kernels
-
-#endif  // CUDANET_POOLING_H
\ No newline at end of file
diff --git a/include/layers/input.hpp b/include/layers/input.hpp
deleted file mode 100644
index ab84c75..0000000
--- a/include/layers/input.hpp
+++ /dev/null
@@ -1,66 +0,0 @@
-#ifndef CUDANET_INPUT_LAYER_H
-#define CUDANET_INPUT_LAYER_H
-
-#include "layer.hpp"
-
-namespace CUDANet::Layers {
-
-/**
- * @brief Input layer, just copies the input to the device
- *
- */
-class Input : public Layer {
-  public:
-    /**
-     * @brief Create a new Input layer
-     *
-     * @param inputSize Size of the input vector
-     */
-    explicit Input(int inputSize);
-
-    /**
-     * @brief Destroy the Input layer
-     *
-     */
-    ~Input();
-
-    /**
-     * @brief Forward pass of the input layer. Just copies the input to the
-     * device
-     *
-     * @param input Host pointer to the input vector
-     * @return Device pointer to the output vector
-     */
-    float* forward(const float* input);
-
-    /**
-     * @brief Get output size
-     *
-     * @return int output size
-     */
-    int get_output_size();
-
-    /**
-     * @brief Get input size
-     *
-     * @return int input size
-     */
-    int getInputSize();
-
-  private:
-    int inputSize;
-
-    float* forwardCPU(const float* input);
-
-#ifdef USE_CUDA
-    float* d_output;
-
-    float* forwardCUDA(const float* input);
-    void initCUDA();
-    void delCUDA();
-#endif
-};
-
-}  // namespace CUDANet::Layers
-
-#endif  // CUDANET_INPUT_LAYER_H
\ No newline at end of file
diff --git a/include/layers/max_pool.hpp b/include/layers/max_pool.hpp
new file mode 100644
index 0000000..0adafe9
--- /dev/null
+++ b/include/layers/max_pool.hpp
@@ -0,0 +1,51 @@
+#pragma once
+
+#include "layer.hpp"
+
+namespace CUDANet::Layers {
+
+class MaxPool2d : public Layer {
+  public:
+    MaxPool2d(
+        CUDANet::Shape input_shape,
+        CUDANet::Shape pooling_shape,
+        CUDANet::Shape stride_shape,
+        CUDANet::Shape padding_shape,
+        CUDANet::Backend* backend
+    );
+    ~MaxPool2d();
+
+    CUDANet::Tensor& forward(CUDANet::Tensor& input) override;
+
+    CUDANet::Shape input_shape() override;
+
+    CUDANet::Shape output_shape() override;
+
+    size_t input_size() override;
+
+    size_t output_size() override;
+
+    void set_weights(void* input) override;
+
+    CUDANet::Tensor& get_weights() override;
+
+    void set_biases(void* input) override;
+
+    CUDANet::Tensor& get_biases() override;
+
+  private:
+    CUDANet::Shape in_shape;
+
+    CUDANet::Shape pooling_shape;
+    CUDANet::Shape stride_shape;
+    CUDANet::Shape padding_shape;
+
+    CUDANet::Shape out_shape;
+    CUDANet::Tensor output;
+
+    CUDANet::Backend* backend;
+};
+
+}  // namespace CUDANet::Layers
diff --git a/include/layers/max_pooling.hpp b/include/layers/max_pooling.hpp
deleted file mode 100644
index bcc66cf..0000000
--- a/include/layers/max_pooling.hpp
+++ /dev/null
@@ -1,63 +0,0 @@
-#ifndef CUDANET_MAX_POOLING_H
-#define CUDANET_MAX_POOLING_H
-
-#include "activation.hpp"
-#include "layer.hpp"
-
-namespace CUDANet::Layers {
-
-class MaxPooling2d : public Layer, public TwoDLayer {
-  public:
-    MaxPooling2d(
-        shape2d inputSize,
-        int nChannels,
-        shape2d poolingSize,
-        shape2d stride,
-        shape2d padding,
-        ActivationType activationType
-    );
-    ~MaxPooling2d();
-
-    float* forward(const float* input);
-
-    /**
-     * @brief Get output size
-     *
-     * @return int output size
-     */
-    int get_output_size();
-
-    /**
-     * @brief Get input size
-     *
-     * @return int input size
-     */
-    int getInputSize();
-
-    shape2d getOutputDims();
-
-  private:
-    shape2d inputSize;
-    int nChannels;
-    shape2d poolingSize;
-    shape2d stride;
-    shape2d padding;
-
-    shape2d outputSize;
-
-    Activation* activation;
-
-    float* forwardCPU(const float* input);
-
-#ifdef USE_CUDA
-    float* d_output;
-    float* forwardCUDA(const float* d_input);
-
-    void initCUDA();
-    void delCUDA();
-#endif
-};
-
-}  // namespace CUDANet::Layers
-
-#endif  // CUDANET_MAX_POOLING_H
\ No newline at end of file
diff --git a/include/layers/output.hpp b/include/layers/output.hpp
deleted file mode 100644
index 28e5634..0000000
--- a/include/layers/output.hpp
+++ /dev/null
@@ -1,59 +0,0 @@
-#ifndef CUDANET_OUTPUT_LAYER_H
-#define CUDANET_OUTPUT_LAYER_H
-
-#include "layer.hpp"
-
-namespace CUDANet::Layers {
-
-class Output : public Layer {
-  public:
-    /**
-     * @brief Create a new Output layer
-     *
-     * @param inputSize Size of the input vector
-     */
-    explicit Output(int inputSize);
-
-    /**
-     * @brief Destroy the Output layer
-     *
-     */
-    ~Output();
-
-    /**
-     * @brief Forward pass of the output layer. Just copies the input from
-     * device to host
-     *
-     * @param input Device pointer to the input vector
-     * @return Host pointer to the output vector
-     */
-    float* forward(const float* input);
-
-    /**
-     * @brief Get output size
-     *
-     * @return int output size
-     */
-    int get_output_size();
-
-    /**
-     * @brief Get input size
-     *
-     * @return int input size
-     */
-    int getInputSize();
-
-  private:
-    int inputSize;
-    float* h_output;
-
-    float* forwardCPU(const float* input);
-
-#ifdef USE_CUDA
-    float* forwardCUDA(const float* input);
-#endif
-};
-
-}  // namespace CUDANet::Layers
-
-#endif  // CUDANET_OUTPUT_LAYER_H
\ No newline at end of file
diff --git a/src/backends/cuda/kernels/pooling.cu b/src/backends/cuda/kernels/pooling.cu
index dbb5e09..29904fa 100644
--- a/src/backends/cuda/kernels/pooling.cu
+++ b/src/backends/cuda/kernels/pooling.cu
@@ -4,35 +4,34 @@
 
 using namespace CUDANet;
 
-__global__ void Kernels::max_pooling(
+__global__ void Kernels::max_pool(
     const float* __restrict__ d_input,
     float* __restrict__ d_output,
-    const shape2d inputSize,
-    const shape2d outputSize,
-    const int nChannels,
-    const shape2d poolingSize,
-    const shape2d stride,
-    const shape2d padding
+    const Shape input_shape,
+    const Shape output_shape,
+    const Shape pool_shape,
+    const Shape stride_shape,
+    const Shape padding_shape
 ) {
     int j = blockDim.x * blockIdx.x + threadIdx.x;
     int i = blockDim.y * blockIdx.y + threadIdx.y;
     int c = blockDim.z * blockIdx.z + threadIdx.z;
 
-    if (i >= outputSize.first || j >= outputSize.second || c >= nChannels) {
+    if (i >= output_shape[0] || j >= output_shape[1] || c >= output_shape[2]) {
         return;
     }
 
     float max = 0.0f;
 
-    for (int k = 0; k < poolingSize.first; k++) {
-        for (int l = 0; l < poolingSize.second; l++) {
-            int inputRow = i * stride.first + k - padding.first;
-            int inputCol = j * stride.second + l - padding.second;
+    for (int k = 0; k < pool_shape[0]; k++) {
+        for (int l = 0; l < pool_shape[1]; l++) {
+            int inputRow = i * stride_shape[0] + k - padding_shape[0];
+            int inputCol = j * stride_shape[1] + l - padding_shape[1];
 
-            if (inputRow >= 0 && inputRow < inputSize.first && inputCol >= 0 &&
-                inputCol < inputSize.second) {
-                int inputIndex = c * inputSize.first * inputSize.second +
-                                 inputRow * inputSize.second + inputCol;
+            if (inputRow >= 0 && inputRow < input_shape[0] && inputCol >= 0 &&
+                inputCol < input_shape[1]) {
+                int inputIndex = c * input_shape[0] * input_shape[1] +
+                                 inputRow * input_shape[1] + inputCol;
                 if (d_input[inputIndex] > max) {
                     max = d_input[inputIndex];
                 }
@@ -41,45 +40,44 @@ __global__ void Kernels::max_pooling(
         }
     }
 
     d_output
-        [c * outputSize.first * outputSize.second + i * outputSize.second + j] =
+        [c * output_shape[0] * output_shape[1] + i * output_shape[1] + j] =
             max;
 }
 
-__global__ void Kernels::avg_pooling(
+__global__ void Kernels::avg_pool(
     const float* __restrict__ d_input,
     float* __restrict__ d_output,
-    const shape2d inputSize,
-    const shape2d outputSize,
-    const int nChannels,
-    const shape2d poolingSize,
-    const shape2d stride,
-    const shape2d padding
+    const Shape input_shape,
+    const Shape output_shape,
+    const Shape pool_shape,
+    const Shape stride_shape,
+    const Shape padding_shape
 ) {
     int j = blockDim.x * blockIdx.x + threadIdx.x;
     int i = blockDim.y * blockIdx.y + threadIdx.y;
     int c = blockDim.z * blockIdx.z + threadIdx.z;
 
-    if (i >= outputSize.first || j >= outputSize.second || c >= nChannels) {
+    if (i >= output_shape[0] || j >= output_shape[1] || c >= output_shape[2]) {
         return;
     }
 
     float sum = 0.0f;
 
-    for (int k = 0; k < poolingSize.first; k++) {
-        for (int l = 0; l < poolingSize.second; l++) {
-            int inputRow = i * stride.first + k - padding.first;
-            int inputCol = j * stride.second + l - padding.second;
+    for (int k = 0; k < pool_shape[0]; k++) {
+        for (int l = 0; l < pool_shape[1]; l++) {
+            int inputRow = i * stride_shape[0] + k - padding_shape[0];
+            int inputCol = j * stride_shape[1] + l - padding_shape[1];
 
-            if (inputRow >= 0 && inputRow < inputSize.first && inputCol >= 0 &&
-                inputCol < inputSize.second) {
-                int inputIndex = c * inputSize.first * inputSize.second +
-                                 inputRow * inputSize.second + inputCol;
+            if (inputRow >= 0 && inputRow < input_shape[0] && inputCol >= 0 &&
+                inputCol < input_shape[1]) {
+                int inputIndex = c * input_shape[0] * input_shape[1] +
+                                 inputRow * input_shape[1] + inputCol;
                 sum += d_input[inputIndex];
             }
         }
     }
 
     d_output
-        [c * outputSize.first * outputSize.second + i * outputSize.second + j] =
-            sum / (poolingSize.first * poolingSize.second);
+        [c * output_shape[0] * output_shape[1] + i * output_shape[1] + j] =
+            sum / (pool_shape[0] * pool_shape[1]);
 }
\ No newline at end of file
diff --git a/src/backends/cuda/layer_ops.cu b/src/backends/cuda/layer_ops.cu
index bcfe0fb..129948b 100644
--- a/src/backends/cuda/layer_ops.cu
+++ b/src/backends/cuda/layer_ops.cu
@@ -2,6 +2,7 @@
 #include "kernels/activation_functions.cuh"
 #include "kernels/convolution.cuh"
 #include "kernels/matmul.cuh"
+#include "kernels/pooling.cuh"
 #include "utils/cuda_helper.cuh"
 
 using namespace CUDANet::Backend;
@@ -108,5 +109,31 @@ CUDANet::Tensor& CUDA::conv2d(
     CUDA_CHECK(cudaGetLastError());
     CUDA_CHECK(cudaDeviceSynchronize());
 
+    return output;
+}
+
+CUDANet::Tensor& CUDA::maxPool2d(
+    const CUDANet::Tensor& input,
+    CUDANet::Tensor& output,
+    CUDANet::Shape input_shape,
+    CUDANet::Shape pool_shape,
+    CUDANet::Shape stride_shape,
+    CUDANet::Shape padding_shape,
+    CUDANet::Shape output_shape
+) {
+    dim3 block(8, 8, 8);
+    dim3 grid(
+        (output_shape[1] + block.x - 1) / block.x,
+        (output_shape[0] + block.y - 1) / block.y,
+        (output_shape[2] + block.z - 1) / block.z
+    );
+
+    Kernels::max_pool<<<grid, block>>>(
+        input.data(), output.data(), input_shape, output_shape, pool_shape,
+        stride_shape, padding_shape
+    );
+    CUDA_CHECK(cudaGetLastError());
+    CUDA_CHECK(cudaDeviceSynchronize());
+
     return output;
 }
\ No newline at end of file
diff --git a/src/backends/cuda/layers/input.cu b/src/backends/cuda/layers/input.cu
deleted file mode 100644
index 05b44be..0000000
--- a/src/backends/cuda/layers/input.cu
+++ /dev/null
@@ -1,22 +0,0 @@
-#include "cuda_helper.cuh"
-#include "input.hpp"
-
-using namespace CUDANet::Layers;
-
-void Input::initCUDA() {
-    d_output = nullptr;
-    CUDA_CHECK(cudaMalloc((void**)&d_output, sizeof(float) * inputSize));
-}
-
-void Input::delCUDA() {
-    cudaFree(d_output);
-}
-
-float* Input::forwardCUDA(const float* input) {
-    CUDA_CHECK(cudaMemcpy(
-        d_output, input, sizeof(float) * inputSize, cudaMemcpyHostToDevice
-    ));
-    CUDA_CHECK(cudaDeviceSynchronize());
-
-    return d_output;
-}
\ No newline at end of file
diff --git a/src/backends/cuda/layers/max_pooling.cu b/src/backends/cuda/layers/max_pooling.cu
deleted file mode 100644
index 6fea6a8..0000000
--- a/src/backends/cuda/layers/max_pooling.cu
+++ /dev/null
@@ -1,38 +0,0 @@
-#include "cuda_helper.cuh"
-#include "max_pooling.hpp"
-#include "pooling.cuh"
-
-using namespace CUDANet::Layers;
-
-void MaxPooling2d::initCUDA() {
-    d_output = nullptr;
-    CUDA_CHECK(cudaMalloc(
-        (void**)&d_output,
-        sizeof(float) * outputSize.first * outputSize.second * nChannels
-    ));
-}
-
-void MaxPooling2d::delCUDA() {
-    cudaFree(d_output);
-}
-
-
-float* MaxPooling2d::forwardCUDA(const float* d_input) {
-    dim3 block(8, 8, 8);
-    dim3 grid(
-        (outputSize.first + block.x - 1) / block.x,
-        (outputSize.second + block.y - 1) / block.y,
-        (nChannels + block.z - 1) / block.z
-    );
-
-    Kernels::max_pooling<<<grid, block>>>(
-        d_input, d_output, inputSize, outputSize, nChannels, poolingSize,
-        stride, padding
-    );
-    CUDA_CHECK(cudaGetLastError());
-
-    activation->activate(d_output);
-    CUDA_CHECK(cudaDeviceSynchronize());
-
-    return d_output;
-}
\ No newline at end of file
diff --git a/src/backends/cuda/layers/output.cu b/src/backends/cuda/layers/output.cu
deleted file mode 100644
index a6dfd96..0000000
--- a/src/backends/cuda/layers/output.cu
+++ /dev/null
@@ -1,14 +0,0 @@
-#include "output.hpp"
-
-#include "cuda_helper.cuh"
-
-using namespace CUDANet::Layers;
-
-float* Output::forwardCUDA(const float* input) {
-    CUDA_CHECK(cudaMemcpy(
-        h_output, input, sizeof(float) * inputSize, cudaMemcpyDeviceToHost
-    ));
-    CUDA_CHECK(cudaDeviceSynchronize());
-
-    return h_output;
-}
\ No newline at end of file
diff --git a/src/layers/input.cpp b/src/layers/input.cpp
deleted file mode 100644
index 8c9affc..0000000
--- a/src/layers/input.cpp
+++ /dev/null
@@ -1,37 +0,0 @@
-#include <stdexcept>
-
-#include "input.hpp"
-
-using namespace CUDANet::Layers;
-
-Input::Input(int inputSize) : inputSize(inputSize) {
-#ifdef USE_CUDA
-    initCUDA();
-#endif
-}
-
-Input::~Input() {
-#ifdef USE_CUDA
-    delCUDA();
-#endif
-}
-
-float* Input::forwardCPU(const float* input) {
-    throw std::logic_error("Not implemented");
-}
-
-float* Input::forward(const float* input) {
-#ifdef USE_CUDA
-    return forwardCUDA(input);
-#else
-    return forwardCPU(input);
-#endif
-}
-
-int Input::get_output_size() {
-    return inputSize;
-}
-
-int Input::getInputSize() {
-    return inputSize;
-}
\ No newline at end of file
diff --git a/src/layers/max_pool.cpp b/src/layers/max_pool.cpp
new file mode 100644
index 0000000..20a7666
--- /dev/null
+++ b/src/layers/max_pool.cpp
@@ -0,0 +1,70 @@
+#include "max_pool.hpp"
+
+#include <stdexcept>
+
+using namespace CUDANet::Layers;
+
+MaxPool2d::MaxPool2d(
+    CUDANet::Shape input_shape,
+    CUDANet::Shape pooling_shape,
+    CUDANet::Shape stride_shape,
+    CUDANet::Shape padding_shape,
+    CUDANet::Backend* backend
+)
+    : in_shape(input_shape),
+      pooling_shape(pooling_shape),
+      stride_shape(stride_shape),
+      padding_shape(padding_shape),
+      backend(backend) {
+    size_t out_h = (in_shape[0] + 2 * padding_shape[0] - pooling_shape[0]) /
+                       stride_shape[0] +
+                   1;
+    size_t out_w = (in_shape[1] + 2 * padding_shape[1] - pooling_shape[1]) /
+                       stride_shape[1] +
+                   1;
+
+    out_shape.resize(3);
+    out_shape[0] = out_h;
+    out_shape[1] = out_w;
+    out_shape[2] = in_shape[2];
+
+    output = CUDANet::Tensor(
+        Shape{out_shape[0] * out_shape[1] * out_shape[2]},
+        CUDANet::DType::FLOAT32, backend
+    );
+}
+
+MaxPool2d::~MaxPool2d() {}
+
+CUDANet::Tensor& MaxPool2d::forward(CUDANet::Tensor& input) {
+    output.zero();
+    backend->maxPool2d(
+        input, output, in_shape, pooling_shape, stride_shape, padding_shape,
+        out_shape
+    );
+    return output;
+}
+
+CUDANet::Shape MaxPool2d::input_shape() {
+    return in_shape;
+}
+
+CUDANet::Shape MaxPool2d::output_shape() {
+    return out_shape;
+}
+
+size_t MaxPool2d::input_size() {
+    return sizeof(float) * in_shape[0] * in_shape[1] * in_shape[2];
+}
+
+size_t MaxPool2d::output_size() {
+    return sizeof(float) * out_shape[0] * out_shape[1] * out_shape[2];
+}
+
+void MaxPool2d::set_weights(void* input) {}
+
+CUDANet::Tensor& MaxPool2d::get_weights() { throw std::logic_error("MaxPool2d has no weights"); }
+
+void MaxPool2d::set_biases(void* input) {}
+
+CUDANet::Tensor& MaxPool2d::get_biases() { throw std::logic_error("MaxPool2d has no biases"); }
\ No newline at end of file
diff --git a/src/layers/max_pooling.cpp b/src/layers/max_pooling.cpp
deleted file mode 100644
index dbf5778..0000000
--- a/src/layers/max_pooling.cpp
+++ /dev/null
@@ -1,67 +0,0 @@
-#include "max_pooling.hpp"
-#include <stdexcept>
-
-using namespace CUDANet::Layers;
-
-MaxPooling2d::MaxPooling2d(
-    shape2d inputSize,
-    int nChannels,
-    shape2d poolingSize,
-    shape2d stride,
-    shape2d padding,
-    ActivationType activationType
-)
-    : inputSize(inputSize),
-      nChannels(nChannels),
-      poolingSize(poolingSize),
-      stride(stride),
-      padding(padding) {
-    outputSize = {
-        (inputSize.first + 2 * padding.first - poolingSize.first) /
-                stride.first +
-            1,
-        (inputSize.second + 2 * padding.second - poolingSize.second) /
-                stride.second +
-            1
-    };
-
-    activation = new Activation(
-        activationType, outputSize.first * outputSize.second * nChannels
-    );
-
-#ifdef USE_CUDA
-    initCUDA();
-#endif
-}
-
-MaxPooling2d::~MaxPooling2d() {
-#ifdef USE_CUDA
-    delCUDA();
-#endif
-    delete activation;
-}
-
-float* MaxPooling2d::forwardCPU(const float* input) {
-    throw std::logic_error("Not implemented");
-}
-
-float* MaxPooling2d::forward(const float* input) {
-#ifdef USE_CUDA
-    return forwardCUDA(input);
-#else
-    return forwardCPU(input);
-#endif
-}
-
-
-int MaxPooling2d::get_output_size() {
-    return outputSize.first * outputSize.second * nChannels;
-}
-
-int MaxPooling2d::getInputSize() {
-    return inputSize.first * inputSize.second * nChannels;
-}
-
-shape2d MaxPooling2d::getOutputDims() {
-    return outputSize;
-}
\ No newline at end of file
diff --git a/src/layers/output.cu b/src/layers/output.cu
deleted file mode 100644
index f62d0f2..0000000
--- a/src/layers/output.cu
+++ /dev/null
@@ -1,34 +0,0 @@
-#include "output.hpp"
-#include <stdexcept>
-
-using namespace CUDANet::Layers;
-
-
-Output::Output(int inputSize) : inputSize(inputSize) {
-    h_output = (float*)malloc(sizeof(float) * inputSize);
-}
-
-Output::~Output() {
-    free(h_output);
-}
-
-float* Output::forwardCPU(const float* input) {
-    throw std::logic_error("Not implemented");
-}
-
-float* Output::forward(const float* input) {
-#ifdef USE_CUDA
-    return forwardCUDA(input);
-#else
-    return forwardCPU(input);
-#endif
-}
-
-int Output::get_output_size() {
-    return inputSize;
-}
-
-
-int Output::getInputSize() {
-    return inputSize;
-}
\ No newline at end of file