Mirror of https://github.com/lordmathis/CUDANet.git (synced 2025-12-24 07:14:22 +00:00)

Commit: Migrate MaxPool2d layer to Tensors
@@ -4,35 +4,34 @@
 using namespace CUDANet;

-__global__ void Kernels::max_pooling(
+__global__ void Kernels::max_pool(
     const float* __restrict__ d_input,
     float* __restrict__ d_output,
-    const shape2d inputSize,
-    const shape2d outputSize,
-    const int nChannels,
-    const shape2d poolingSize,
-    const shape2d stride,
-    const shape2d padding
+    const Shape input_shape,
+    const Shape output_shape,
+    const Shape pool_shape,
+    const Shape stride_shape,
+    const Shape padding_shape
 ) {
     int j = blockDim.x * blockIdx.x + threadIdx.x;
     int i = blockDim.y * blockIdx.y + threadIdx.y;
     int c = blockDim.z * blockIdx.z + threadIdx.z;

-    if (i >= outputSize.first || j >= outputSize.second || c >= nChannels) {
+    if (i >= output_shape[0] || j >= output_shape[1] || c >= output_shape[2]) {
         return;
     }

     float max = 0.0f;

-    for (int k = 0; k < poolingSize.first; k++) {
-        for (int l = 0; l < poolingSize.second; l++) {
-            int inputRow = i * stride.first + k - padding.first;
-            int inputCol = j * stride.second + l - padding.second;
+    for (int k = 0; k < pool_shape[0]; k++) {
+        for (int l = 0; l < pool_shape[1]; l++) {
+            int inputRow = i * stride_shape[0] + k - padding_shape[0];
+            int inputCol = j * stride_shape[1] + l - padding_shape[1];

-            if (inputRow >= 0 && inputRow < inputSize.first && inputCol >= 0 &&
-                inputCol < inputSize.second) {
-                int inputIndex = c * inputSize.first * inputSize.second +
-                                 inputRow * inputSize.second + inputCol;
+            if (inputRow >= 0 && inputRow < input_shape[0] && inputCol >= 0 &&
+                inputCol < input_shape[1]) {
+                int inputIndex = c * input_shape[0] * input_shape[1] +
+                                 inputRow * input_shape[1] + inputCol;
                 if (d_input[inputIndex] > max) {
                     max = d_input[inputIndex];
                 }
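Reviewer note: both the old `max_pooling` and the new `max_pool` initialize the running maximum to `0.0f`, so a window containing only negative values yields 0 rather than its true maximum; the migration carries this behavior over unchanged. A host-side reference like the sketch below is handy for validating the kernel. It is not part of the commit — the function name and the raw-array stand-in for `Shape` are illustrative — and it uses the conventional `-FLT_MAX` start value:

```cpp
#include <algorithm>
#include <cfloat>
#include <vector>

// Reference CHW max pooling matching the kernel's indexing. Shapes follow
// the kernel's convention: {rows, cols, channels} for input/output,
// {rows, cols} for pool/stride/padding.
std::vector<float> max_pool_ref(
    const std::vector<float>& in,
    const int input_shape[3], const int output_shape[3],
    const int pool_shape[2], const int stride_shape[2],
    const int padding_shape[2]
) {
    std::vector<float> out(
        (size_t)output_shape[0] * output_shape[1] * output_shape[2]);
    for (int c = 0; c < output_shape[2]; c++)
        for (int i = 0; i < output_shape[0]; i++)
            for (int j = 0; j < output_shape[1]; j++) {
                // -FLT_MAX (not the kernel's 0.0f) keeps all-negative
                // windows correct.
                float m = -FLT_MAX;
                for (int k = 0; k < pool_shape[0]; k++)
                    for (int l = 0; l < pool_shape[1]; l++) {
                        int r = i * stride_shape[0] + k - padding_shape[0];
                        int s = j * stride_shape[1] + l - padding_shape[1];
                        if (r >= 0 && r < input_shape[0] &&
                            s >= 0 && s < input_shape[1])
                            m = std::max(
                                m, in[c * input_shape[0] * input_shape[1] +
                                      r * input_shape[1] + s]);
                    }
                out[c * output_shape[0] * output_shape[1] +
                    i * output_shape[1] + j] = m;
            }
    return out;
}
```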
@@ -41,45 +40,44 @@ __global__ void Kernels::max_pooling(
     }

     d_output
-        [c * outputSize.first * outputSize.second + i * outputSize.second + j] =
+        [c * output_shape[0] * output_shape[1] + i * output_shape[1] + j] =
             max;
 }

-__global__ void Kernels::avg_pooling(
+__global__ void Kernels::avg_pool(
     const float* __restrict__ d_input,
     float* __restrict__ d_output,
-    const shape2d inputSize,
-    const shape2d outputSize,
-    const int nChannels,
-    const shape2d poolingSize,
-    const shape2d stride,
-    const shape2d padding
+    const Shape input_shape,
+    const Shape output_shape,
+    const Shape pool_shape,
+    const Shape stride_shape,
+    const Shape padding_shape
 ) {
     int j = blockDim.x * blockIdx.x + threadIdx.x;
     int i = blockDim.y * blockIdx.y + threadIdx.y;
     int c = blockDim.z * blockIdx.z + threadIdx.z;

-    if (i >= outputSize.first || j >= outputSize.second || c >= nChannels) {
+    if (i >= output_shape[0] || j >= output_shape[1] || c >= output_shape[2]) {
         return;
     }

     float sum = 0.0f;

-    for (int k = 0; k < poolingSize.first; k++) {
-        for (int l = 0; l < poolingSize.second; l++) {
-            int inputRow = i * stride.first + k - padding.first;
-            int inputCol = j * stride.second + l - padding.second;
+    for (int k = 0; k < pool_shape[0]; k++) {
+        for (int l = 0; l < pool_shape[1]; l++) {
+            int inputRow = i * stride_shape[0] + k - padding_shape[0];
+            int inputCol = j * stride_shape[1] + l - padding_shape[1];

-            if (inputRow >= 0 && inputRow < inputSize.first && inputCol >= 0 &&
-                inputCol < inputSize.second) {
-                int inputIndex = c * inputSize.first * inputSize.second +
-                                 inputRow * inputSize.second + inputCol;
+            if (inputRow >= 0 && inputRow < input_shape[0] && inputCol >= 0 &&
+                inputCol < input_shape[1]) {
+                int inputIndex = c * input_shape[0] * input_shape[1] +
+                                 inputRow * input_shape[1] + inputCol;
                 sum += d_input[inputIndex];
             }
         }
     }

     d_output
-        [c * outputSize.first * outputSize.second + i * outputSize.second + j] =
-            sum / (poolingSize.first * poolingSize.second);
+        [c * output_shape[0] * output_shape[1] + i * output_shape[1] + j] =
+            sum / (pool_shape[0] * pool_shape[1]);
 }
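Reviewer note: `avg_pool` divides by the full window area `pool_shape[0] * pool_shape[1]` even when padding clips the window at the border, i.e. it behaves like count-include-pad average pooling. Both kernels also trust the caller to pass an `output_shape` consistent with the usual pooling arithmetic; a minimal helper sketch (illustrative, not in the diff):

```cpp
// Standard pooling output extent, per spatial dimension:
//   out = floor((in + 2 * pad - pool) / stride) + 1
// Integer division performs the floor because valid pooling
// configurations keep the numerator non-negative.
inline int pooled_extent(int in, int pool, int stride, int pad) {
    return (in + 2 * pad - pool) / stride + 1;
}

// Example: pooled_extent(224, 2, 2, 0) == 112.
```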
@@ -2,6 +2,7 @@
 #include "kernels/activation_functions.cuh"
 #include "kernels/convolution.cuh"
 #include "kernels/matmul.cuh"
+#include "kernels/pooling.cuh"
 #include "utils/cuda_helper.cuh"

 using namespace CUDANet::Backend;
@@ -108,5 +109,31 @@ CUDANet::Tensor& CUDA::conv2d(
     CUDA_CHECK(cudaGetLastError());
     CUDA_CHECK(cudaDeviceSynchronize());

     return output;
 }
+
+CUDANet::Tensor& CUDA::maxPool2d(
+    const CUDANet::Tensor& input,
+    CUDANet::Tensor& output,
+    CUDANet::Shape input_shape,
+    CUDANet::Shape pool_shape,
+    CUDANet::Shape stride_shape,
+    CUDANet::Shape padding_shape,
+    CUDANet::Shape output_shape
+) {
+    dim3 block(8, 8, 8);
+    dim3 grid(
+        (output_shape[0] + block.x - 1) / block.x,
+        (output_shape[1] + block.y - 1) / block.y,
+        (output_shape[2] + block.z - 1) / block.z
+    );
+
+    Kernels::max_pool<<<grid, block>>>(
+        input.data<float>(), output.data<float>(), input_shape, output_shape, pool_shape,
+        stride_shape, padding_shape
+    );
+    CUDA_CHECK(cudaGetLastError());
+    CUDA_CHECK(cudaDeviceSynchronize());
+
+    return output;
+}
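Reviewer note: one detail carried over from the old layer code is the launch arithmetic, which pairs `grid.x` with `output_shape[0]` and `grid.y` with `output_shape[1]`. Inside the kernel, however, the x index `j` is guarded against `output_shape[1]` and the y index `i` against `output_shape[0]`. The extents coincide for square outputs, but for non-square outputs one axis appears under-provisioned. A sketch of the pairing that matches the kernel's own guards (an observation about the diff, not code from it):

```cpp
// Grid extents matched to the kernel's index/guard pairing:
// thread x -> j, bounded by output_shape[1] (columns);
// thread y -> i, bounded by output_shape[0] (rows);
// thread z -> c, bounded by output_shape[2] (channels).
dim3 block(8, 8, 8);
dim3 grid(
    (output_shape[1] + block.x - 1) / block.x,  // columns (j)
    (output_shape[0] + block.y - 1) / block.y,  // rows (i)
    (output_shape[2] + block.z - 1) / block.z   // channels (c)
);
```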
@@ -1,22 +0,0 @@
|
||||
#include "cuda_helper.cuh"
|
||||
#include "input.hpp"
|
||||
|
||||
using namespace CUDANet::Layers;
|
||||
|
||||
void Input::initCUDA() {
|
||||
d_output = nullptr;
|
||||
CUDA_CHECK(cudaMalloc((void**)&d_output, sizeof(float) * inputSize));
|
||||
}
|
||||
|
||||
void Input::delCUDA() {
|
||||
cudaFree(d_output);
|
||||
}
|
||||
|
||||
float* Input::forwardCUDA(const float* input) {
|
||||
CUDA_CHECK(cudaMemcpy(
|
||||
d_output, input, sizeof(float) * inputSize, cudaMemcpyHostToDevice
|
||||
));
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
return d_output;
|
||||
}
|
||||
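Reviewer note: the deleted `Input::forwardCUDA` was a plain host-to-device copy into a layer-owned buffer. As a free-function sketch of the same behavior (where this responsibility now lives, presumably inside `CUDANet::Tensor`, is not shown in this diff; `CUDA_CHECK` is the repo's macro from utils/cuda_helper.cuh):

```cpp
#include <cstddef>

// Copy n floats from host memory into a pre-allocated device buffer,
// mirroring the removed Input::forwardCUDA.
float* copy_input_to_device(float* d_out, const float* h_in, std::size_t n) {
    CUDA_CHECK(cudaMemcpy(d_out, h_in, sizeof(float) * n,
                          cudaMemcpyHostToDevice));
    CUDA_CHECK(cudaDeviceSynchronize());
    return d_out;
}
```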
@@ -1,38 +0,0 @@
|
||||
#include "cuda_helper.cuh"
|
||||
#include "max_pooling.hpp"
|
||||
#include "pooling.cuh"
|
||||
|
||||
using namespace CUDANet::Layers;
|
||||
|
||||
void MaxPooling2d::initCUDA() {
|
||||
d_output = nullptr;
|
||||
CUDA_CHECK(cudaMalloc(
|
||||
(void**)&d_output,
|
||||
sizeof(float) * outputSize.first * outputSize.second * nChannels
|
||||
));
|
||||
}
|
||||
|
||||
void MaxPooling2d::delCUDA() {
|
||||
cudaFree(d_output);
|
||||
}
|
||||
|
||||
|
||||
float* MaxPooling2d::forwardCUDA(const float* d_input) {
|
||||
dim3 block(8, 8, 8);
|
||||
dim3 grid(
|
||||
(outputSize.first + block.x - 1) / block.x,
|
||||
(outputSize.second + block.y - 1) / block.y,
|
||||
(nChannels + block.z - 1) / block.z
|
||||
);
|
||||
|
||||
Kernels::max_pooling<<<grid, block>>>(
|
||||
d_input, d_output, inputSize, outputSize, nChannels, poolingSize,
|
||||
stride, padding
|
||||
);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
activation->activate(d_output);
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
return d_output;
|
||||
}
|
||||
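Reviewer note: comparing the removed layer code with the new backend entry point, the old `forwardCUDA` owned its `d_output` buffer, passed `nChannels` as a separate argument, and fused the activation via `activation->activate(d_output)` right after the kernel; the new `CUDA::maxPool2d` writes into a caller-provided `Tensor`, folds the channel count into `output_shape[2]`, and leaves activation to the caller. A hypothetical call sketch (the `Tensor` and `Shape` construction APIs assumed here are not shown in this diff):

```cpp
// Illustrative only: exact Tensor/Shape constructors are assumptions.
// 224x224x3 input, 2x2 max pool, stride 2, no padding -> 112x112x3 output.
CUDANet::Shape in_shape  = {224, 224, 3};
CUDANet::Shape out_shape = {112, 112, 3};

CUDANet::Tensor input(in_shape);   // assumed device-resident tensor
CUDANet::Tensor output(out_shape);

CUDA::maxPool2d(
    input, output, in_shape,
    /*pool*/ {2, 2}, /*stride*/ {2, 2}, /*padding*/ {0, 0},
    out_shape
);
// The old layer applied its activation at this point; with the new API
// that is now the caller's responsibility.
```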
@@ -1,14 +0,0 @@
|
||||
#include "output.hpp"
|
||||
|
||||
#include "cuda_helper.cuh"
|
||||
|
||||
using namespace CUDANet::Layers;
|
||||
|
||||
float* Output::forwardCUDA(const float* input) {
|
||||
CUDA_CHECK(cudaMemcpy(
|
||||
h_output, input, sizeof(float) * inputSize, cudaMemcpyDeviceToHost
|
||||
));
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
return h_output;
|
||||
}
|
||||