Implement max pooling test

This commit is contained in:
2024-03-20 21:44:04 +01:00
parent c062e89972
commit dfff0360d9
9 changed files with 134 additions and 32 deletions

View File

@@ -7,6 +7,7 @@ __global__ void Kernels::max_pooling(
const float* __restrict__ d_input,
float* __restrict__ d_output,
const int inputSize,
const int outputSize,
const int nChannels,
const int poolingSize,
const int stride
@@ -15,7 +16,7 @@ __global__ void Kernels::max_pooling(
int i = blockDim.y * blockIdx.y + threadIdx.y;
int c = blockDim.z * blockIdx.z + threadIdx.z;
if (i >= inputSize || j >= inputSize || c >= nChannels) {
if (i >= outputSize || j >= outputSize || c >= nChannels) {
return;
}
@@ -32,13 +33,14 @@ __global__ void Kernels::max_pooling(
}
}
d_output[c * inputSize * inputSize + i * inputSize + j] = max;
d_output[c * outputSize * outputSize + i * outputSize + j] = max;
}
__global__ void Kernels::avg_pooling(
const float* __restrict__ d_input,
float* __restrict__ d_output,
const int inputSize,
const int outputSize,
const int nChannels,
const int poolingSize,
const int stride
@@ -62,6 +64,6 @@ __global__ void Kernels::avg_pooling(
}
}
d_output[c * inputSize * inputSize + i * inputSize + j] =
d_output[c * outputSize * outputSize + i * outputSize + j] =
sum / (poolingSize * poolingSize);
}

View File

@@ -43,7 +43,7 @@ float* AvgPooling2D::forward(const float* d_input) {
);
Kernels::avg_pooling<<<grid, block>>>(
d_input, d_output, inputSize, nChannels, poolingSize, stride
d_input, d_output, inputSize, outputSize, nChannels, poolingSize, stride
);
return d_output;

View File

@@ -15,7 +15,7 @@ MaxPooling2D::MaxPooling2D(
: inputSize(inputSize), nChannels(nChannels), poolingSize(poolingSize), stride(stride) {
outputSize = (inputSize - poolingSize) / stride + 1;
outputSize = (inputSize - 1) / stride + 1;
activation = Activation(
activationType, outputSize * outputSize * nChannels
@@ -46,7 +46,7 @@ float* MaxPooling2D::forward(const float* d_input) {
);
Kernels::max_pooling<<<grid, block>>>(
d_input, d_output, inputSize, nChannels, poolingSize, stride
d_input, d_output, inputSize, outputSize, nChannels, poolingSize, stride
);
return d_output;