mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-05 17:34:21 +00:00
Implement max pooling layer
This commit is contained in:
30
include/kernels/pooling.cuh
Normal file
30
include/kernels/pooling.cuh
Normal file
@@ -0,0 +1,30 @@
|
||||
#ifndef CUDANET_POOLING_H
|
||||
#define CUDANET_POOLING_H
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
namespace CUDANet::Kernels {
|
||||
|
||||
__global__ void max_pooling(
|
||||
const float* __restrict__ d_input,
|
||||
float* __restrict__ d_output,
|
||||
const int inputSize,
|
||||
const int nChannels,
|
||||
const int poolingSize,
|
||||
const int stride,
|
||||
const int paddingSize
|
||||
);
|
||||
|
||||
__global__ void avg_pooling(
|
||||
const float* __restrict__ d_input,
|
||||
float* __restrict__ d_output,
|
||||
const int inputSize,
|
||||
const int nChannels,
|
||||
const int poolingSize,
|
||||
const int stride,
|
||||
const int paddingSize
|
||||
);
|
||||
|
||||
} // namespace CUDANet::Kernels
|
||||
|
||||
#endif // CUDANET_POOLING_H
|
||||
@@ -25,16 +25,17 @@ class Conv2d : public WeightedLayer {
|
||||
* @param stride Convolution stride
|
||||
* @param numFilters Number of output filters
|
||||
* @param padding Padding type ('SAME' or 'VALID')
|
||||
* @param activationType Activation function type ('RELU', 'SIGMOID', 'SOFTMAX' or 'NONE')
|
||||
* @param activationType Activation function type ('RELU', 'SIGMOID',
|
||||
* 'SOFTMAX' or 'NONE')
|
||||
*/
|
||||
Conv2d(
|
||||
int inputSize,
|
||||
int inputChannels,
|
||||
int kernelSize,
|
||||
int stride,
|
||||
int numFilters,
|
||||
Layers::Padding padding,
|
||||
Layers::ActivationType activationType
|
||||
int inputSize,
|
||||
int inputChannels,
|
||||
int kernelSize,
|
||||
int stride,
|
||||
int numFilters,
|
||||
Padding padding,
|
||||
ActivationType activationType
|
||||
);
|
||||
|
||||
/**
|
||||
@@ -107,7 +108,7 @@ class Conv2d : public WeightedLayer {
|
||||
float* d_biases;
|
||||
|
||||
// Kernels
|
||||
Layers::Activation activation;
|
||||
Activation activation;
|
||||
|
||||
/**
|
||||
* @brief Initialize weights of the convolutional layer with zeros
|
||||
|
||||
42
include/layers/max_pooling.cuh
Normal file
42
include/layers/max_pooling.cuh
Normal file
@@ -0,0 +1,42 @@
|
||||
#ifndef CUDANET_MAX_POOLING_H
|
||||
#define CUDANET_MAX_POOLING_H
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include "layer.cuh"
|
||||
#include "activation.cuh"
|
||||
|
||||
namespace CUDANet::Layers {
|
||||
|
||||
class MaxPooling2D : public SequentialLayer {
|
||||
public:
|
||||
MaxPooling2D(
|
||||
int inputSize,
|
||||
int nChannels,
|
||||
int poolingSize,
|
||||
int stride,
|
||||
Padding padding,
|
||||
ActivationType activationType
|
||||
);
|
||||
~MaxPooling2D();
|
||||
|
||||
float* forward(const float* d_input);
|
||||
|
||||
private:
|
||||
int inputSize;
|
||||
int nChannels;
|
||||
int poolingSize;
|
||||
int stride;
|
||||
int paddingSize;
|
||||
|
||||
int outputSize;
|
||||
int gridSize;
|
||||
|
||||
float* d_output;
|
||||
|
||||
Activation activation;
|
||||
};
|
||||
|
||||
} // namespace CUDANet::Layers
|
||||
|
||||
#endif // CUDANET_MAX_POOLING_H
|
||||
@@ -3,9 +3,9 @@
|
||||
#include "activation_functions.cuh"
|
||||
#include "cuda_helper.cuh"
|
||||
|
||||
using namespace CUDANet::Kernels;
|
||||
using namespace CUDANet;
|
||||
|
||||
__global__ void sigmoid(
|
||||
__global__ void Kernels::sigmoid(
|
||||
const float* __restrict__ src,
|
||||
float* __restrict__ dst,
|
||||
const unsigned int len
|
||||
@@ -18,7 +18,7 @@ __global__ void sigmoid(
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void relu(
|
||||
__global__ void Kernels::relu(
|
||||
const float* __restrict__ src,
|
||||
float* __restrict__ dst,
|
||||
const unsigned int len
|
||||
@@ -31,7 +31,7 @@ __global__ void relu(
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void softmax_exp(
|
||||
__global__ void Kernels::softmax_exp(
|
||||
const float* __restrict__ src,
|
||||
float* __restrict__ dst,
|
||||
const unsigned int len
|
||||
@@ -44,7 +44,7 @@ __global__ void softmax_exp(
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void softmax_sum(
|
||||
__global__ void Kernels::softmax_sum(
|
||||
const float* __restrict__ d_vector,
|
||||
float* __restrict__ d_output,
|
||||
const unsigned int w
|
||||
@@ -66,7 +66,7 @@ __global__ void softmax_sum(
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void softmax_div(
|
||||
__global__ void Kernels::softmax_div(
|
||||
const float* __restrict__ src,
|
||||
float* __restrict__ dst,
|
||||
const float* __restrict__ sum,
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
|
||||
#include "convolution.cuh"
|
||||
|
||||
using namespace CUDANet::Kernels;
|
||||
using namespace CUDANet;
|
||||
|
||||
__global__ void convolution(
|
||||
__global__ void Kernels::convolution(
|
||||
const float* __restrict__ d_input,
|
||||
const float* __restrict__ d_kernel,
|
||||
float* __restrict__ d_output,
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
#include "cuda_helper.cuh"
|
||||
#include "matmul.cuh"
|
||||
|
||||
using namespace CUDANet::Kernels;
|
||||
using namespace CUDANet;
|
||||
|
||||
__global__ void mat_vec_mul(
|
||||
__global__ void Kernels::mat_vec_mul(
|
||||
const float* __restrict__ d_matrix,
|
||||
const float* __restrict__ d_vector,
|
||||
float* __restrict__ d_output,
|
||||
@@ -37,7 +37,7 @@ __global__ void mat_vec_mul(
|
||||
d_output[tid] = temp;
|
||||
}
|
||||
|
||||
__global__ void vec_vec_add(
|
||||
__global__ void Kernels::vec_vec_add(
|
||||
const float* __restrict__ d_vector1,
|
||||
const float* __restrict__ d_vector2,
|
||||
float* __restrict__ d_output,
|
||||
|
||||
92
src/kernels/pooling.cu
Normal file
92
src/kernels/pooling.cu
Normal file
@@ -0,0 +1,92 @@
|
||||
#include "pooling.cuh"
|
||||
|
||||
#include "cuda_helper.cuh"
|
||||
|
||||
using namespace CUDANet;
|
||||
|
||||
__global__ void Kernels::max_pooling(
|
||||
const float* __restrict__ d_input,
|
||||
float* __restrict__ d_output,
|
||||
const int inputSize,
|
||||
const int nChannels,
|
||||
const int poolingSize,
|
||||
const int stride,
|
||||
const int paddingSize
|
||||
) {
|
||||
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
if (tid >= inputSize * inputSize * nChannels) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Get output index
|
||||
int c = tid / (inputSize * inputSize);
|
||||
int i = tid % (inputSize * inputSize) / inputSize;
|
||||
int j = tid % inputSize;
|
||||
|
||||
float max = 0.0f;
|
||||
|
||||
for (int k = 0; k < poolingSize; k++) {
|
||||
for (int l = 0; l < poolingSize; l++) {
|
||||
|
||||
if (i * stride + k < paddingSize ||
|
||||
i * stride + k >= (inputSize + paddingSize) ||
|
||||
j * stride + l < paddingSize ||
|
||||
j * stride + l >= (inputSize + paddingSize)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
int inputIndex = c * inputSize * inputSize +
|
||||
(i * stride + k - paddingSize) * inputSize +
|
||||
(j * stride + l - paddingSize);
|
||||
|
||||
if (d_input[inputIndex] > max) {
|
||||
max = d_input[inputIndex];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
d_output[tid] = max;
|
||||
}
|
||||
|
||||
__global__ void Kernels::avg_pooling(
|
||||
const float* __restrict__ d_input,
|
||||
float* __restrict__ d_output,
|
||||
const int inputSize,
|
||||
const int nChannels,
|
||||
const int poolingSize,
|
||||
const int stride,
|
||||
const int paddingSize
|
||||
) {
|
||||
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
if (tid >= inputSize * inputSize * nChannels) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Get output index
|
||||
int c = tid / (inputSize * inputSize);
|
||||
int i = tid % (inputSize * inputSize) / inputSize;
|
||||
int j = tid % inputSize;
|
||||
|
||||
float sum = 0.0f;
|
||||
|
||||
for (int k = 0; k < poolingSize; k++) {
|
||||
for (int l = 0; l < poolingSize; l++) {
|
||||
|
||||
if (i * stride + k < paddingSize ||
|
||||
i * stride + k >= (inputSize + paddingSize) ||
|
||||
j * stride + l < paddingSize ||
|
||||
j * stride + l >= (inputSize + paddingSize)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int inputIndex = c * inputSize * inputSize +
|
||||
(i * stride + k - paddingSize) * inputSize +
|
||||
(j * stride + l - paddingSize);
|
||||
|
||||
sum += d_input[inputIndex];
|
||||
}
|
||||
}
|
||||
|
||||
d_output[tid] = sum / (poolingSize * poolingSize);
|
||||
}
|
||||
@@ -39,7 +39,7 @@ Conv2d::Conv2d(
|
||||
break;
|
||||
}
|
||||
|
||||
activation = Layers::Activation(
|
||||
activation = Activation(
|
||||
activationType, outputSize * outputSize * numFilters
|
||||
);
|
||||
|
||||
|
||||
59
src/layers/max_pooling.cu
Normal file
59
src/layers/max_pooling.cu
Normal file
@@ -0,0 +1,59 @@
|
||||
#include "max_pooling.cuh"
|
||||
#include "cuda_helper.cuh"
|
||||
#include "pooling.cuh"
|
||||
|
||||
using namespace CUDANet::Layers;
|
||||
|
||||
|
||||
MaxPooling2D::MaxPooling2D(
|
||||
int inputSize,
|
||||
int nChannels,
|
||||
int poolingSize,
|
||||
int stride,
|
||||
Padding padding,
|
||||
ActivationType activationType
|
||||
)
|
||||
: inputSize(inputSize), nChannels(nChannels), poolingSize(poolingSize), stride(stride) {
|
||||
|
||||
|
||||
switch (padding) {
|
||||
case SAME:
|
||||
outputSize = inputSize;
|
||||
paddingSize = ((stride - 1) * inputSize - stride + poolingSize) / 2;
|
||||
break;
|
||||
|
||||
case VALID:
|
||||
paddingSize = 0;
|
||||
outputSize = (inputSize - poolingSize) / stride + 1;
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
activation = Activation(
|
||||
activationType, outputSize * outputSize * nChannels
|
||||
);
|
||||
|
||||
d_output = nullptr;
|
||||
CUDA_CHECK(cudaMalloc(
|
||||
(void**)&d_output, sizeof(float) * outputSize * outputSize * nChannels
|
||||
));
|
||||
|
||||
gridSize = (outputSize * outputSize * nChannels + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
|
||||
}
|
||||
|
||||
|
||||
MaxPooling2D::~MaxPooling2D() {
|
||||
cudaFree(d_output);
|
||||
}
|
||||
|
||||
|
||||
float* MaxPooling2D::forward(const float* d_input) {
|
||||
Kernels::max_pooling<<<gridSize, BLOCK_SIZE>>>(
|
||||
d_input, d_output, inputSize, nChannels, poolingSize, stride, paddingSize
|
||||
);
|
||||
|
||||
return d_output;
|
||||
}
|
||||
Reference in New Issue
Block a user