Implement avg pooling

This commit is contained in:
2024-03-19 22:33:43 +01:00
parent a0fc1b00ae
commit ef63cbd9f1
6 changed files with 92 additions and 45 deletions

View File

@@ -10,8 +10,7 @@ __global__ void Kernels::max_pooling(
const int inputSize,
const int nChannels,
const int poolingSize,
const int stride,
const int paddingSize
const int stride
) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= inputSize * inputSize * nChannels) {
@@ -28,17 +27,9 @@ __global__ void Kernels::max_pooling(
for (int k = 0; k < poolingSize; k++) {
for (int l = 0; l < poolingSize; l++) {
if (i * stride + k < paddingSize ||
i * stride + k >= (inputSize + paddingSize) ||
j * stride + l < paddingSize ||
j * stride + l >= (inputSize + paddingSize)) {
continue;
}
int inputIndex = c * inputSize * inputSize +
(i * stride + k - paddingSize) * inputSize +
(j * stride + l - paddingSize);
(i * stride + k) * inputSize +
(j * stride + l);
if (d_input[inputIndex] > max) {
max = d_input[inputIndex];
@@ -55,8 +46,7 @@ __global__ void Kernels::avg_pooling(
const int inputSize,
const int nChannels,
const int poolingSize,
const int stride,
const int paddingSize
const int stride
) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= inputSize * inputSize * nChannels) {
@@ -73,16 +63,9 @@ __global__ void Kernels::avg_pooling(
for (int k = 0; k < poolingSize; k++) {
for (int l = 0; l < poolingSize; l++) {
if (i * stride + k < paddingSize ||
i * stride + k >= (inputSize + paddingSize) ||
j * stride + l < paddingSize ||
j * stride + l >= (inputSize + paddingSize)) {
continue;
}
int inputIndex = c * inputSize * inputSize +
(i * stride + k - paddingSize) * inputSize +
(j * stride + l - paddingSize);
(i * stride + k) * inputSize +
(j * stride + l);
sum += d_input[inputIndex];
}