Implement avg pooling

commit ef63cbd9f1
parent a0fc1b00ae
Date: 2024-03-19 22:33:43 +01:00

6 changed files with 92 additions and 45 deletions


@@ -11,8 +11,7 @@ __global__ void max_pooling(
     const int inputSize,
     const int nChannels,
     const int poolingSize,
-    const int stride,
-    const int paddingSize
+    const int stride
 );
 
 __global__ void avg_pooling(
@@ -21,8 +20,7 @@ __global__ void avg_pooling(
     const int inputSize,
     const int nChannels,
     const int poolingSize,
-    const int stride,
-    const int paddingSize
+    const int stride
 );
 
 } // namespace CUDANet::Kernels
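With paddingSize gone from both signatures, the kernels implement VALID pooling only: the window indices i * stride + k and j * stride + l address the input directly, which is why the per-element bounds check that skipped padded positions disappears from the kernel bodies below.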


@@ -0,0 +1,38 @@
+#ifndef CUDANET_AVG_POOLING_H
+#define CUDANET_AVG_POOLING_H
+
+#include "activation.cuh"
+#include "layer.cuh"
+
+namespace CUDANet::Layers {
+
+class AvgPooling2D : public SequentialLayer {
+  public:
+    AvgPooling2D(
+        int inputSize,
+        int nChannels,
+        int poolingSize,
+        int stride,
+        ActivationType activationType
+    );
+    ~AvgPooling2D();
+
+    float* forward(const float* d_input);
+
+  private:
+    int inputSize;
+    int nChannels;
+    int poolingSize;
+    int stride;
+    int outputSize;
+    int gridSize;
+
+    float* d_output;
+
+    Activation activation;
+};
+
+} // namespace CUDANet::Layers
+
+#endif // CUDANET_AVG_POOLING_H
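Note that AvgPooling2D owns its output buffer: as the implementation further below shows, d_output is allocated in the constructor and freed in the destructor, so callers should not free the pointer returned by forward.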


@@ -15,7 +15,6 @@ class MaxPooling2D : public SequentialLayer {
         int nChannels,
         int poolingSize,
         int stride,
-        Padding padding,
         ActivationType activationType
     );
     ~MaxPooling2D();
@@ -27,7 +26,6 @@ class MaxPooling2D : public SequentialLayer {
     int nChannels;
     int poolingSize;
     int stride;
-    int paddingSize;
     int outputSize;
     int gridSize;


@@ -10,8 +10,7 @@ __global__ void Kernels::max_pooling(
     const int inputSize,
     const int nChannels,
     const int poolingSize,
-    const int stride,
-    const int paddingSize
+    const int stride
 ) {
     int tid = blockDim.x * blockIdx.x + threadIdx.x;
 
     if (tid >= inputSize * inputSize * nChannels) {
@@ -28,17 +27,9 @@ __global__ void Kernels::max_pooling(
     for (int k = 0; k < poolingSize; k++) {
         for (int l = 0; l < poolingSize; l++) {
-            if (i * stride + k < paddingSize ||
-                i * stride + k >= (inputSize + paddingSize) ||
-                j * stride + l < paddingSize ||
-                j * stride + l >= (inputSize + paddingSize)) {
-                continue;
-            }
             int inputIndex = c * inputSize * inputSize +
-                             (i * stride + k - paddingSize) * inputSize +
-                             (j * stride + l - paddingSize);
+                             (i * stride + k) * inputSize +
+                             (j * stride + l);
 
             if (d_input[inputIndex] > max) {
                 max = d_input[inputIndex];
@@ -55,8 +46,7 @@ __global__ void Kernels::avg_pooling(
     const int inputSize,
     const int nChannels,
     const int poolingSize,
-    const int stride,
-    const int paddingSize
+    const int stride
 ) {
     int tid = blockDim.x * blockIdx.x + threadIdx.x;
 
     if (tid >= inputSize * inputSize * nChannels) {
@@ -73,16 +63,9 @@ __global__ void Kernels::avg_pooling(
     for (int k = 0; k < poolingSize; k++) {
         for (int l = 0; l < poolingSize; l++) {
-            if (i * stride + k < paddingSize ||
-                i * stride + k >= (inputSize + paddingSize) ||
-                j * stride + l < paddingSize ||
-                j * stride + l >= (inputSize + paddingSize)) {
-                continue;
-            }
             int inputIndex = c * inputSize * inputSize +
-                             (i * stride + k - paddingSize) * inputSize +
-                             (j * stride + l - paddingSize);
+                             (i * stride + k) * inputSize +
+                             (j * stride + l);
 
             sum += d_input[inputIndex];
         }
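The hunks above show only fragments of the kernel bodies. For orientation, here is a plausible reconstruction of the simplified avg_pooling kernel after this commit; the decomposition of tid into (c, i, j), the tightened bounds guard, and the final division by the window area are not visible in the diff, so treat them as assumptions rather than the repository's exact code.

// Plausible reconstruction (not verbatim repository code) of the
// average-pooling kernel after paddingSize was removed.
__global__ void avg_pooling(
    const float* d_input,
    float* d_output,
    const int inputSize,
    const int nChannels,
    const int poolingSize,
    const int stride
) {
    int tid = blockDim.x * blockIdx.x + threadIdx.x;

    int outputSize = (inputSize - poolingSize) / stride + 1;

    // The diff guards against inputSize * inputSize * nChannels; guarding on
    // the output extent is tighter, since the rounded-up grid launches a few
    // extra threads in the last block.
    if (tid >= outputSize * outputSize * nChannels) {
        return;
    }

    // Assumed: map the flat thread id onto (channel, output row, output col).
    int c = tid / (outputSize * outputSize);
    int i = (tid % (outputSize * outputSize)) / outputSize;
    int j = tid % outputSize;

    float sum = 0.0f;
    for (int k = 0; k < poolingSize; k++) {
        for (int l = 0; l < poolingSize; l++) {
            // With padding gone, window indices address the input directly.
            int inputIndex = c * inputSize * inputSize +
                             (i * stride + k) * inputSize +
                             (j * stride + l);
            sum += d_input[inputIndex];
        }
    }

    // Assumed: divide by the window area to produce the mean.
    d_output[tid] = sum / (poolingSize * poolingSize);
}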

src/layers/avg_pooling.cu (new file, +44)

@@ -0,0 +1,44 @@
+#include "avg_pooling.cuh"
+#include "cuda_helper.cuh"
+#include "pooling.cuh"
+
+using namespace CUDANet::Layers;
+
+AvgPooling2D::AvgPooling2D(
+    int inputSize,
+    int nChannels,
+    int poolingSize,
+    int stride,
+    ActivationType activationType
+)
+    : inputSize(inputSize), nChannels(nChannels), poolingSize(poolingSize), stride(stride) {
+    outputSize = (inputSize - poolingSize) / stride + 1;
+
+    activation = Activation(
+        activationType, outputSize * outputSize * nChannels
+    );
+
+    d_output = nullptr;
+    CUDA_CHECK(cudaMalloc(
+        (void**)&d_output, sizeof(float) * outputSize * outputSize * nChannels
+    ));
+
+    gridSize = (outputSize * outputSize * nChannels + BLOCK_SIZE - 1) / BLOCK_SIZE;
+}
+
+AvgPooling2D::~AvgPooling2D() {
+    cudaFree(d_output);
+}
+
+float* AvgPooling2D::forward(const float* d_input) {
+    Kernels::avg_pooling<<<gridSize, BLOCK_SIZE>>>(
+        d_input, d_output, inputSize, nChannels, poolingSize, stride
+    );
+
+    return d_output;
+}
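To exercise the new layer end to end, a minimal host-side usage sketch follows. Only the AvgPooling2D interface comes from the files above; the concrete sizes, the expected output, and ActivationType::NONE (or whatever the enum's no-op member is actually called) are assumptions for illustration.

#include <cstdio>
#include <vector>

#include "avg_pooling.cuh"
#include "cuda_helper.cuh"

int main() {
    const int inputSize = 4, nChannels = 1, poolingSize = 2, stride = 2;
    const int outputSize = (inputSize - poolingSize) / stride + 1;  // = 2

    // Constant input: every 2x2 window averages to exactly 3.0f.
    std::vector<float> h_input(inputSize * inputSize * nChannels, 3.0f);

    float* d_input = nullptr;
    CUDA_CHECK(cudaMalloc(
        (void**)&d_input, sizeof(float) * h_input.size()
    ));
    CUDA_CHECK(cudaMemcpy(
        d_input, h_input.data(), sizeof(float) * h_input.size(),
        cudaMemcpyHostToDevice
    ));

    // ActivationType::NONE is an assumption; substitute the actual no-op member.
    CUDANet::Layers::AvgPooling2D pool(
        inputSize, nChannels, poolingSize, stride,
        CUDANet::Layers::ActivationType::NONE
    );

    // forward returns the layer-owned device buffer; do not free it here.
    float* d_output = pool.forward(d_input);

    std::vector<float> h_output(outputSize * outputSize * nChannels);
    CUDA_CHECK(cudaMemcpy(
        h_output.data(), d_output, sizeof(float) * h_output.size(),
        cudaMemcpyDeviceToHost
    ));

    for (float v : h_output) {
        std::printf("%f\n", v);  // expect 3.000000 for each of the 4 outputs
    }

    CUDA_CHECK(cudaFree(d_input));
    return 0;
}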


@@ -10,26 +10,12 @@ MaxPooling2D::MaxPooling2D(
     int nChannels,
     int poolingSize,
     int stride,
-    Padding padding,
     ActivationType activationType
 )
     : inputSize(inputSize), nChannels(nChannels), poolingSize(poolingSize), stride(stride) {
-    switch (padding) {
-        case SAME:
-            outputSize = inputSize;
-            paddingSize = ((stride - 1) * inputSize - stride + poolingSize) / 2;
-            break;
-        case VALID:
-            paddingSize = 0;
-            outputSize = (inputSize - poolingSize) / stride + 1;
-            break;
-        default:
-            break;
-    }
+    outputSize = (inputSize - poolingSize) / stride + 1;
 
     activation = Activation(
         activationType, outputSize * outputSize * nChannels
@@ -52,7 +38,7 @@ MaxPooling2D::~MaxPooling2D() {
 float* MaxPooling2D::forward(const float* d_input) {
     Kernels::max_pooling<<<gridSize, BLOCK_SIZE>>>(
-        d_input, d_output, inputSize, nChannels, poolingSize, stride, paddingSize
+        d_input, d_output, inputSize, nChannels, poolingSize, stride
     );
 
     return d_output;
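One consequence worth noting: with the Padding parameter removed, MaxPooling2D (like the new AvgPooling2D) always computes VALID pooling, outputSize = (inputSize - poolingSize) / stride + 1. For inputSize = 112, poolingSize = 2, stride = 2, that yields (112 - 2) / 2 + 1 = 56, whereas the removed SAME branch would have kept outputSize at 112.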