Use 3d memory layout for pooling

This commit is contained in:
2024-03-20 19:17:30 +01:00
parent 5860faf85e
commit c062e89972
3 changed files with 49 additions and 43 deletions

View File

@@ -1,6 +1,5 @@
#include "pooling.cuh"
#include "cuda_helper.cuh" #include "cuda_helper.cuh"
#include "pooling.cuh"
using namespace CUDANet; using namespace CUDANet;
@@ -12,24 +11,20 @@ __global__ void Kernels::max_pooling(
const int poolingSize, const int poolingSize,
const int stride const int stride
) { ) {
int tid = blockDim.x * blockIdx.x + threadIdx.x; int j = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= inputSize * inputSize * nChannels) { int i = blockDim.y * blockIdx.y + threadIdx.y;
int c = blockDim.z * blockIdx.z + threadIdx.z;
if (i >= inputSize || j >= inputSize || c >= nChannels) {
return; return;
} }
// Get output index
int c = tid / (inputSize * inputSize);
int i = tid % (inputSize * inputSize) / inputSize;
int j = tid % inputSize;
float max = 0.0f; float max = 0.0f;
for (int k = 0; k < poolingSize; k++) { for (int k = 0; k < poolingSize; k++) {
for (int l = 0; l < poolingSize; l++) { for (int l = 0; l < poolingSize; l++) {
int inputIndex = c * inputSize * inputSize + int inputIndex = c * inputSize * inputSize +
(i * stride + k) * inputSize + (i * stride + k) * inputSize + (j * stride + l);
(j * stride + l);
if (d_input[inputIndex] > max) { if (d_input[inputIndex] > max) {
max = d_input[inputIndex]; max = d_input[inputIndex];
@@ -37,7 +32,7 @@ __global__ void Kernels::max_pooling(
} }
} }
d_output[tid] = max; d_output[c * inputSize * inputSize + i * inputSize + j] = max;
} }
__global__ void Kernels::avg_pooling( __global__ void Kernels::avg_pooling(
@@ -48,28 +43,25 @@ __global__ void Kernels::avg_pooling(
const int poolingSize, const int poolingSize,
const int stride const int stride
) { ) {
int tid = blockDim.x * blockIdx.x + threadIdx.x; int j = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= inputSize * inputSize * nChannels) { int i = blockDim.y * blockIdx.y + threadIdx.y;
int c = blockDim.z * blockIdx.z + threadIdx.z;
if (i >= inputSize || j >= inputSize || c >= nChannels) {
return; return;
} }
// Get output index
int c = tid / (inputSize * inputSize);
int i = tid % (inputSize * inputSize) / inputSize;
int j = tid % inputSize;
float sum = 0.0f; float sum = 0.0f;
for (int k = 0; k < poolingSize; k++) { for (int k = 0; k < poolingSize; k++) {
for (int l = 0; l < poolingSize; l++) { for (int l = 0; l < poolingSize; l++) {
int inputIndex = c * inputSize * inputSize + int inputIndex = c * inputSize * inputSize +
(i * stride + k) * inputSize + (i * stride + k) * inputSize + (j * stride + l);
(j * stride + l);
sum += d_input[inputIndex]; sum += d_input[inputIndex];
} }
} }
d_output[tid] = sum / (poolingSize * poolingSize); d_output[c * inputSize * inputSize + i * inputSize + j] =
sum / (poolingSize * poolingSize);
} }

View File

@@ -11,32 +11,38 @@ AvgPooling2D::AvgPooling2D(
int stride, int stride,
ActivationType activationType ActivationType activationType
) )
: inputSize(inputSize), nChannels(nChannels), poolingSize(poolingSize), stride(stride) { : inputSize(inputSize),
nChannels(nChannels),
poolingSize(poolingSize),
stride(stride) {
outputSize = (inputSize - poolingSize) / stride + 1; outputSize = (inputSize - poolingSize) / stride + 1;
activation = Activation( activation =
activationType, outputSize * outputSize * nChannels Activation(activationType, outputSize * outputSize * nChannels);
);
d_output = nullptr; d_output = nullptr;
CUDA_CHECK(cudaMalloc( CUDA_CHECK(cudaMalloc(
(void**)&d_output, sizeof(float) * outputSize * outputSize * nChannels (void**)&d_output, sizeof(float) * outputSize * outputSize * nChannels
)); ));
gridSize = (outputSize * outputSize * nChannels + BLOCK_SIZE - 1) / BLOCK_SIZE; gridSize =
(outputSize * outputSize * nChannels + BLOCK_SIZE - 1) / BLOCK_SIZE;
} }
AvgPooling2D::~AvgPooling2D() { AvgPooling2D::~AvgPooling2D() {
cudaFree(d_output); cudaFree(d_output);
} }
float* AvgPooling2D::forward(const float* d_input) { float* AvgPooling2D::forward(const float* d_input) {
Kernels::avg_pooling<<<gridSize, BLOCK_SIZE>>>(
dim3 block(8, 8, 8);
dim3 grid(
(outputSize + block.x - 1) / block.x,
(outputSize + block.y - 1) / block.y,
(nChannels + block.z - 1) / block.z
);
Kernels::avg_pooling<<<grid, block>>>(
d_input, d_output, inputSize, nChannels, poolingSize, stride d_input, d_output, inputSize, nChannels, poolingSize, stride
); );

View File

@@ -37,7 +37,15 @@ MaxPooling2D::~MaxPooling2D() {
float* MaxPooling2D::forward(const float* d_input) { float* MaxPooling2D::forward(const float* d_input) {
Kernels::max_pooling<<<gridSize, BLOCK_SIZE>>>(
dim3 block(8,8,8);
dim3 grid(
(outputSize + block.x - 1) / block.x,
(outputSize + block.y - 1) / block.y,
(nChannels + block.z - 1) / block.z
);
Kernels::max_pooling<<<grid, block>>>(
d_input, d_output, inputSize, nChannels, poolingSize, stride d_input, d_output, inputSize, nChannels, poolingSize, stride
); );