mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-12-22 14:24:22 +00:00
Migrate MaxPool2d layer to Tensors
This commit is contained in:
@@ -52,6 +52,16 @@ class Backend {
|
|||||||
const CUDANet::Shape stride_shape,
|
const CUDANet::Shape stride_shape,
|
||||||
const CUDANet::Shape out_shape
|
const CUDANet::Shape out_shape
|
||||||
) = 0;
|
) = 0;
|
||||||
|
|
||||||
|
virtual CUDANet::Tensor& maxPool2d(
|
||||||
|
const CUDANet::Tensor& input,
|
||||||
|
CUDANet::Tensor& output,
|
||||||
|
CUDANet::Shape input_shape,
|
||||||
|
CUDANet::Shape pool_shape,
|
||||||
|
CUDANet::Shape stride_shape,
|
||||||
|
CUDANet::Shape padding_shape,
|
||||||
|
CUDANet::Shape output_shape
|
||||||
|
) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace CUDANet
|
} // namespace CUDANet
|
||||||
@@ -48,6 +48,16 @@ class CUDA : public Backend {
|
|||||||
const CUDANet::Shape stride_shape,
|
const CUDANet::Shape stride_shape,
|
||||||
const CUDANet::Shape out_shape
|
const CUDANet::Shape out_shape
|
||||||
) override;
|
) override;
|
||||||
|
|
||||||
|
CUDANet::Tensor& CUDA::maxPool2d(
|
||||||
|
const CUDANet::Tensor& input,
|
||||||
|
CUDANet::Tensor& output,
|
||||||
|
CUDANet::Shape input_shape,
|
||||||
|
CUDANet::Shape pool_shape,
|
||||||
|
CUDANet::Shape stride_shape,
|
||||||
|
CUDANet::Shape padding_shape,
|
||||||
|
CUDANet::Shape output_shape
|
||||||
|
) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace CUDANet::Backend
|
} // namespace CUDANet::Backend
|
||||||
@@ -1,33 +1,28 @@
|
|||||||
#ifndef CUDANET_POOLING_H
|
#pragma once
|
||||||
#define CUDANET_POOLING_H
|
|
||||||
|
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
#include "layer.hpp"
|
#include "layer.hpp"
|
||||||
|
|
||||||
namespace CUDANet::Kernels {
|
namespace CUDANet::Kernels {
|
||||||
|
|
||||||
__global__ void max_pooling(
|
__global__ void max_pool(
|
||||||
const float* __restrict__ d_input,
|
const float* __restrict__ d_input,
|
||||||
float* __restrict__ d_output,
|
float* __restrict__ d_output,
|
||||||
const shape2d inputSize,
|
const Shape input_shape,
|
||||||
const shape2d outputSize,
|
const Shape output_shape,
|
||||||
const int nChannels,
|
const Shape pool_shape,
|
||||||
const shape2d poolingSize,
|
const Shape stride_shape,
|
||||||
const shape2d stride,
|
const Shape padding_shape
|
||||||
const shape2d padding
|
|
||||||
);
|
);
|
||||||
|
|
||||||
__global__ void avg_pooling(
|
__global__ void avg_pool(
|
||||||
const float* __restrict__ d_input,
|
const float* __restrict__ d_input,
|
||||||
float* __restrict__ d_output,
|
float* __restrict__ d_output,
|
||||||
const shape2d inputSize,
|
const Shape input_shape,
|
||||||
const shape2d outputSize,
|
const Shape output_shape,
|
||||||
const int nChannels,
|
const Shape pool_shape,
|
||||||
const shape2d poolingSize,
|
const Shape stride_shape,
|
||||||
const shape2d stride,
|
const Shape padding_shape
|
||||||
const shape2d padding
|
|
||||||
);
|
);
|
||||||
|
|
||||||
} // namespace CUDANet::Kernels
|
} // namespace CUDANet::Kernels
|
||||||
|
|
||||||
#endif // CUDANET_POOLING_H
|
|
||||||
@@ -1,66 +0,0 @@
|
|||||||
#ifndef CUDANET_INPUT_LAYER_H
|
|
||||||
#define CUDANET_INPUT_LAYER_H
|
|
||||||
|
|
||||||
#include "layer.hpp"
|
|
||||||
|
|
||||||
namespace CUDANet::Layers {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Input layer, just copies the input to the device
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
class Input : public Layer {
|
|
||||||
public:
|
|
||||||
/**
|
|
||||||
* @brief Create a new Input layer
|
|
||||||
*
|
|
||||||
* @param inputSize Size of the input vector
|
|
||||||
*/
|
|
||||||
explicit Input(int inputSize);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Destroy the Input layer
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
~Input();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Forward pass of the input layer. Just copies the input to the
|
|
||||||
* device
|
|
||||||
*
|
|
||||||
* @param input Host pointer to the input vector
|
|
||||||
* @return Device pointer to the output vector
|
|
||||||
*/
|
|
||||||
float* forward(const float* input);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Get output size
|
|
||||||
*
|
|
||||||
* @return int output size
|
|
||||||
*/
|
|
||||||
int get_output_size();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Get input size
|
|
||||||
*
|
|
||||||
* @return int input size
|
|
||||||
*/
|
|
||||||
int getInputSize();
|
|
||||||
|
|
||||||
private:
|
|
||||||
int inputSize;
|
|
||||||
|
|
||||||
float* forwardCPU(const float* input);
|
|
||||||
|
|
||||||
#ifdef USE_CUDA
|
|
||||||
float* d_output;
|
|
||||||
|
|
||||||
float* forwardCUDA(const float* input);
|
|
||||||
void initCUDA();
|
|
||||||
void delCUDA();
|
|
||||||
#endif
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace CUDANet::Layers
|
|
||||||
|
|
||||||
#endif // CUDANET_INPUT_LAYER_H
|
|
||||||
51
include/layers/max_pool.hpp
Normal file
51
include/layers/max_pool.hpp
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "layer.hpp"
|
||||||
|
|
||||||
|
namespace CUDANet::Layers {
|
||||||
|
|
||||||
|
class MaxPool2d : public Layer {
|
||||||
|
public:
|
||||||
|
MaxPool2d(
|
||||||
|
CUDANet::Shape input_shape,
|
||||||
|
CUDANet::Shape pooling_shape,
|
||||||
|
CUDANet::Shape stride_shape,
|
||||||
|
CUDANet::Shape padding_shape,
|
||||||
|
CUDANet::Backend* backend
|
||||||
|
);
|
||||||
|
~MaxPool2d();
|
||||||
|
|
||||||
|
CUDANet::Tensor& forward(CUDANet::Tensor &input) override;
|
||||||
|
|
||||||
|
CUDANet::Shape input_shape() override;
|
||||||
|
|
||||||
|
CUDANet::Shape output_shape() override;
|
||||||
|
|
||||||
|
size_t input_size() override;
|
||||||
|
|
||||||
|
size_t output_size() override;
|
||||||
|
|
||||||
|
void set_weights(void *input) override;
|
||||||
|
|
||||||
|
CUDANet::Tensor& get_weights() override;
|
||||||
|
|
||||||
|
void set_biases(void *input) override;
|
||||||
|
|
||||||
|
CUDANet::Tensor& get_biases() override;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
private:
|
||||||
|
CUDANet::Shape in_shape;
|
||||||
|
|
||||||
|
CUDANet::Shape pooling_shape;
|
||||||
|
CUDANet::Shape stride_shape;
|
||||||
|
CUDANet::Shape padding_shape;
|
||||||
|
|
||||||
|
CUDANet::Shape out_shape;
|
||||||
|
CUDANet::Tensor output;
|
||||||
|
|
||||||
|
CUDANet::Backend *backend;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace CUDANet::Layers
|
||||||
@@ -1,63 +0,0 @@
|
|||||||
#ifndef CUDANET_MAX_POOLING_H
|
|
||||||
#define CUDANET_MAX_POOLING_H
|
|
||||||
|
|
||||||
#include "activation.hpp"
|
|
||||||
#include "layer.hpp"
|
|
||||||
|
|
||||||
namespace CUDANet::Layers {
|
|
||||||
|
|
||||||
class MaxPooling2d : public Layer, public TwoDLayer {
|
|
||||||
public:
|
|
||||||
MaxPooling2d(
|
|
||||||
shape2d inputSize,
|
|
||||||
int nChannels,
|
|
||||||
shape2d poolingSize,
|
|
||||||
shape2d stride,
|
|
||||||
shape2d padding,
|
|
||||||
ActivationType activationType
|
|
||||||
);
|
|
||||||
~MaxPooling2d();
|
|
||||||
|
|
||||||
float* forward(const float* input);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Get output size
|
|
||||||
*
|
|
||||||
* @return int output size
|
|
||||||
*/
|
|
||||||
int get_output_size();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Get input size
|
|
||||||
*
|
|
||||||
* @return int input size
|
|
||||||
*/
|
|
||||||
int getInputSize();
|
|
||||||
|
|
||||||
shape2d getOutputDims();
|
|
||||||
|
|
||||||
private:
|
|
||||||
shape2d inputSize;
|
|
||||||
int nChannels;
|
|
||||||
shape2d poolingSize;
|
|
||||||
shape2d stride;
|
|
||||||
shape2d padding;
|
|
||||||
|
|
||||||
shape2d outputSize;
|
|
||||||
|
|
||||||
Activation* activation;
|
|
||||||
|
|
||||||
float* forwardCPU(const float* input);
|
|
||||||
|
|
||||||
#ifdef USE_CUDA
|
|
||||||
float* d_output;
|
|
||||||
float* forwardCUDA(const float* d_input);
|
|
||||||
|
|
||||||
void initCUDA();
|
|
||||||
void delCUDA();
|
|
||||||
#endif
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace CUDANet::Layers
|
|
||||||
|
|
||||||
#endif // CUDANET_MAX_POOLING_H
|
|
||||||
@@ -1,59 +0,0 @@
|
|||||||
#ifndef CUDANET_OUTPUT_LAYER_H
|
|
||||||
#define CUDANET_OUTPUT_LAYER_H
|
|
||||||
|
|
||||||
#include "layer.hpp"
|
|
||||||
|
|
||||||
namespace CUDANet::Layers {
|
|
||||||
|
|
||||||
class Output : public Layer {
|
|
||||||
public:
|
|
||||||
/**
|
|
||||||
* @brief Create a new Output layer
|
|
||||||
*
|
|
||||||
* @param inputSize Size of the input vector
|
|
||||||
*/
|
|
||||||
explicit Output(int inputSize);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Destroy the Output layer
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
~Output();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Forward pass of the output layer. Just copies the input from
|
|
||||||
* device to host
|
|
||||||
*
|
|
||||||
* @param input Device pointer to the input vector
|
|
||||||
* @return Host pointer to the output vector
|
|
||||||
*/
|
|
||||||
float* forward(const float* input);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Get output size
|
|
||||||
*
|
|
||||||
* @return int output size
|
|
||||||
*/
|
|
||||||
int get_output_size();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Get input size
|
|
||||||
*
|
|
||||||
* @return int input size
|
|
||||||
*/
|
|
||||||
int getInputSize();
|
|
||||||
|
|
||||||
private:
|
|
||||||
int inputSize;
|
|
||||||
float* h_output;
|
|
||||||
|
|
||||||
float* forwardCPU(const float* input);
|
|
||||||
|
|
||||||
#ifdef USE_CUDA
|
|
||||||
float* forwardCUDA(const float* input);
|
|
||||||
#endif
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace CUDANet::Layers
|
|
||||||
|
|
||||||
#endif // CUDANET_OUTPUT_LAYER_H
|
|
||||||
@@ -4,35 +4,34 @@
|
|||||||
|
|
||||||
using namespace CUDANet;
|
using namespace CUDANet;
|
||||||
|
|
||||||
__global__ void Kernels::max_pooling(
|
__global__ void Kernels::max_pool(
|
||||||
const float* __restrict__ d_input,
|
const float* __restrict__ d_input,
|
||||||
float* __restrict__ d_output,
|
float* __restrict__ d_output,
|
||||||
const shape2d inputSize,
|
const Shape input_shape,
|
||||||
const shape2d outputSize,
|
const Shape output_shape,
|
||||||
const int nChannels,
|
const Shape pool_shape,
|
||||||
const shape2d poolingSize,
|
const Shape stride_shape,
|
||||||
const shape2d stride,
|
const Shape padding_shape
|
||||||
const shape2d padding
|
|
||||||
) {
|
) {
|
||||||
int j = blockDim.x * blockIdx.x + threadIdx.x;
|
int j = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
int i = blockDim.y * blockIdx.y + threadIdx.y;
|
int i = blockDim.y * blockIdx.y + threadIdx.y;
|
||||||
int c = blockDim.z * blockIdx.z + threadIdx.z;
|
int c = blockDim.z * blockIdx.z + threadIdx.z;
|
||||||
|
|
||||||
if (i >= outputSize.first || j >= outputSize.second || c >= nChannels) {
|
if (i >= output_shape[0] || j >= output_shape[1] || c >= output_shape[2]) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
float max = 0.0f;
|
float max = 0.0f;
|
||||||
|
|
||||||
for (int k = 0; k < poolingSize.first; k++) {
|
for (int k = 0; k < pool_shape[0]; k++) {
|
||||||
for (int l = 0; l < poolingSize.second; l++) {
|
for (int l = 0; l < pool_shape[1]; l++) {
|
||||||
int inputRow = i * stride.first + k - padding.first;
|
int inputRow = i * stride_shape[0] + k - padding_shape[0];
|
||||||
int inputCol = j * stride.second + l - padding.second;
|
int inputCol = j * stride_shape[1] + l - padding_shape[1];
|
||||||
|
|
||||||
if (inputRow >= 0 && inputRow < inputSize.first && inputCol >= 0 &&
|
if (inputRow >= 0 && inputRow < input_shape[0] && inputCol >= 0 &&
|
||||||
inputCol < inputSize.second) {
|
inputCol < input_shape[1]) {
|
||||||
int inputIndex = c * inputSize.first * inputSize.second +
|
int inputIndex = c * input_shape[0] * input_shape[1] +
|
||||||
inputRow * inputSize.second + inputCol;
|
inputRow * input_shape[1] + inputCol;
|
||||||
if (d_input[inputIndex] > max) {
|
if (d_input[inputIndex] > max) {
|
||||||
max = d_input[inputIndex];
|
max = d_input[inputIndex];
|
||||||
}
|
}
|
||||||
@@ -41,45 +40,44 @@ __global__ void Kernels::max_pooling(
|
|||||||
}
|
}
|
||||||
|
|
||||||
d_output
|
d_output
|
||||||
[c * outputSize.first * outputSize.second + i * outputSize.second + j] =
|
[c * output_shape[0] * output_shape[1] + i * output_shape[1] + j] =
|
||||||
max;
|
max;
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void Kernels::avg_pooling(
|
__global__ void Kernels::avg_pool(
|
||||||
const float* __restrict__ d_input,
|
const float* __restrict__ d_input,
|
||||||
float* __restrict__ d_output,
|
float* __restrict__ d_output,
|
||||||
const shape2d inputSize,
|
const Shape input_shape,
|
||||||
const shape2d outputSize,
|
const Shape output_shape,
|
||||||
const int nChannels,
|
const Shape pool_shape,
|
||||||
const shape2d poolingSize,
|
const Shape stride_shape,
|
||||||
const shape2d stride,
|
const Shape padding_shape
|
||||||
const shape2d padding
|
|
||||||
) {
|
) {
|
||||||
int j = blockDim.x * blockIdx.x + threadIdx.x;
|
int j = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
int i = blockDim.y * blockIdx.y + threadIdx.y;
|
int i = blockDim.y * blockIdx.y + threadIdx.y;
|
||||||
int c = blockDim.z * blockIdx.z + threadIdx.z;
|
int c = blockDim.z * blockIdx.z + threadIdx.z;
|
||||||
|
|
||||||
if (i >= outputSize.first || j >= outputSize.second || c >= nChannels) {
|
if (i >= output_shape[0] || j >= output_shape[1] || c >= output_shape[2]) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
float sum = 0.0f;
|
float sum = 0.0f;
|
||||||
|
|
||||||
for (int k = 0; k < poolingSize.first; k++) {
|
for (int k = 0; k < pool_shape[0]; k++) {
|
||||||
for (int l = 0; l < poolingSize.second; l++) {
|
for (int l = 0; l < pool_shape[1]; l++) {
|
||||||
int inputRow = i * stride.first + k - padding.first;
|
int inputRow = i * stride_shape[0] + k - padding_shape[0];
|
||||||
int inputCol = j * stride.second + l - padding.second;
|
int inputCol = j * stride_shape[1] + l - padding_shape[1];
|
||||||
|
|
||||||
if (inputRow >= 0 && inputRow < inputSize.first && inputCol >= 0 &&
|
if (inputRow >= 0 && inputRow < input_shape[0] && inputCol >= 0 &&
|
||||||
inputCol < inputSize.second) {
|
inputCol < input_shape[1]) {
|
||||||
int inputIndex = c * inputSize.first * inputSize.second +
|
int inputIndex = c * input_shape[0] * input_shape[1] +
|
||||||
inputRow * inputSize.second + inputCol;
|
inputRow * input_shape[1] + inputCol;
|
||||||
sum += d_input[inputIndex];
|
sum += d_input[inputIndex];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
d_output
|
d_output
|
||||||
[c * outputSize.first * outputSize.second + i * outputSize.second + j] =
|
[c * output_shape[0] * output_shape[1] + i * output_shape[1] + j] =
|
||||||
sum / (poolingSize.first * poolingSize.second);
|
sum / (pool_shape[0] * pool_shape[1]);
|
||||||
}
|
}
|
||||||
@@ -2,6 +2,7 @@
|
|||||||
#include "kernels/activation_functions.cuh"
|
#include "kernels/activation_functions.cuh"
|
||||||
#include "kernels/convolution.cuh"
|
#include "kernels/convolution.cuh"
|
||||||
#include "kernels/matmul.cuh"
|
#include "kernels/matmul.cuh"
|
||||||
|
#include "kernels/pooling.cuh"
|
||||||
#include "utils/cuda_helper.cuh"
|
#include "utils/cuda_helper.cuh"
|
||||||
|
|
||||||
using namespace CUDANet::Backend;
|
using namespace CUDANet::Backend;
|
||||||
@@ -110,3 +111,29 @@ CUDANet::Tensor& CUDA::conv2d(
|
|||||||
|
|
||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
CUDANet::Tensor& CUDA::maxPool2d(
|
||||||
|
const CUDANet::Tensor& input,
|
||||||
|
CUDANet::Tensor& output,
|
||||||
|
CUDANet::Shape input_shape,
|
||||||
|
CUDANet::Shape pool_shape,
|
||||||
|
CUDANet::Shape stride_shape,
|
||||||
|
CUDANet::Shape padding_shape,
|
||||||
|
CUDANet::Shape output_shape
|
||||||
|
) {
|
||||||
|
dim3 block(8, 8, 8);
|
||||||
|
dim3 grid(
|
||||||
|
(output_shape[0] + block.x - 1) / block.x,
|
||||||
|
(output_shape[1] + block.y - 1) / block.y,
|
||||||
|
(output_shape[2] + block.z - 1) / block.z
|
||||||
|
);
|
||||||
|
|
||||||
|
Kernels::max_pool<<<grid, block>>>(
|
||||||
|
input.data<float>(), output.data<float>(), input_shape, output_shape, pool_shape,
|
||||||
|
stride_shape, padding_shape
|
||||||
|
);
|
||||||
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
CUDA_CHECK(cudaDeviceSynchronize());
|
||||||
|
|
||||||
|
return output;
|
||||||
|
}
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
#include "cuda_helper.cuh"
|
|
||||||
#include "input.hpp"
|
|
||||||
|
|
||||||
using namespace CUDANet::Layers;
|
|
||||||
|
|
||||||
void Input::initCUDA() {
|
|
||||||
d_output = nullptr;
|
|
||||||
CUDA_CHECK(cudaMalloc((void**)&d_output, sizeof(float) * inputSize));
|
|
||||||
}
|
|
||||||
|
|
||||||
void Input::delCUDA() {
|
|
||||||
cudaFree(d_output);
|
|
||||||
}
|
|
||||||
|
|
||||||
float* Input::forwardCUDA(const float* input) {
|
|
||||||
CUDA_CHECK(cudaMemcpy(
|
|
||||||
d_output, input, sizeof(float) * inputSize, cudaMemcpyHostToDevice
|
|
||||||
));
|
|
||||||
CUDA_CHECK(cudaDeviceSynchronize());
|
|
||||||
|
|
||||||
return d_output;
|
|
||||||
}
|
|
||||||
@@ -1,38 +0,0 @@
|
|||||||
#include "cuda_helper.cuh"
|
|
||||||
#include "max_pooling.hpp"
|
|
||||||
#include "pooling.cuh"
|
|
||||||
|
|
||||||
using namespace CUDANet::Layers;
|
|
||||||
|
|
||||||
void MaxPooling2d::initCUDA() {
|
|
||||||
d_output = nullptr;
|
|
||||||
CUDA_CHECK(cudaMalloc(
|
|
||||||
(void**)&d_output,
|
|
||||||
sizeof(float) * outputSize.first * outputSize.second * nChannels
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
void MaxPooling2d::delCUDA() {
|
|
||||||
cudaFree(d_output);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
float* MaxPooling2d::forwardCUDA(const float* d_input) {
|
|
||||||
dim3 block(8, 8, 8);
|
|
||||||
dim3 grid(
|
|
||||||
(outputSize.first + block.x - 1) / block.x,
|
|
||||||
(outputSize.second + block.y - 1) / block.y,
|
|
||||||
(nChannels + block.z - 1) / block.z
|
|
||||||
);
|
|
||||||
|
|
||||||
Kernels::max_pooling<<<grid, block>>>(
|
|
||||||
d_input, d_output, inputSize, outputSize, nChannels, poolingSize,
|
|
||||||
stride, padding
|
|
||||||
);
|
|
||||||
CUDA_CHECK(cudaGetLastError());
|
|
||||||
|
|
||||||
activation->activate(d_output);
|
|
||||||
CUDA_CHECK(cudaDeviceSynchronize());
|
|
||||||
|
|
||||||
return d_output;
|
|
||||||
}
|
|
||||||
@@ -1,14 +0,0 @@
|
|||||||
#include "output.hpp"
|
|
||||||
|
|
||||||
#include "cuda_helper.cuh"
|
|
||||||
|
|
||||||
using namespace CUDANet::Layers;
|
|
||||||
|
|
||||||
float* Output::forwardCUDA(const float* input) {
|
|
||||||
CUDA_CHECK(cudaMemcpy(
|
|
||||||
h_output, input, sizeof(float) * inputSize, cudaMemcpyDeviceToHost
|
|
||||||
));
|
|
||||||
CUDA_CHECK(cudaDeviceSynchronize());
|
|
||||||
|
|
||||||
return h_output;
|
|
||||||
}
|
|
||||||
@@ -1,37 +0,0 @@
|
|||||||
#include <stdexcept>
|
|
||||||
|
|
||||||
#include "input.hpp"
|
|
||||||
|
|
||||||
using namespace CUDANet::Layers;
|
|
||||||
|
|
||||||
Input::Input(int inputSize) : inputSize(inputSize) {
|
|
||||||
#ifdef USE_CUDA
|
|
||||||
initCUDA();
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
Input::~Input() {
|
|
||||||
#ifdef USE_CUDA
|
|
||||||
delCUDA();
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
float* Input::forwardCPU(const float* input) {
|
|
||||||
throw std::logic_error("Not implemented");
|
|
||||||
}
|
|
||||||
|
|
||||||
float* Input::forward(const float* input) {
|
|
||||||
#ifdef USE_CUDA
|
|
||||||
return forwardCUDA(input);
|
|
||||||
#else
|
|
||||||
return forwardCPU(input);
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
int Input::get_output_size() {
|
|
||||||
return inputSize;
|
|
||||||
}
|
|
||||||
|
|
||||||
int Input::getInputSize() {
|
|
||||||
return inputSize;
|
|
||||||
}
|
|
||||||
70
src/layers/max_pool.cpp
Normal file
70
src/layers/max_pool.cpp
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
#include "max_pool.hpp"
|
||||||
|
|
||||||
|
#include <stdexcept>
|
||||||
|
|
||||||
|
using namespace CUDANet::Layers;
|
||||||
|
|
||||||
|
MaxPool2d::MaxPool2d(
|
||||||
|
CUDANet::Shape input_shape,
|
||||||
|
CUDANet::Shape pooling_shape,
|
||||||
|
CUDANet::Shape stride_shape,
|
||||||
|
CUDANet::Shape padding_shape,
|
||||||
|
CUDANet::Backend* backend
|
||||||
|
)
|
||||||
|
: in_shape(input_shape),
|
||||||
|
pooling_shape(pooling_shape),
|
||||||
|
stride_shape(stride_shape),
|
||||||
|
padding_shape(padding_shape),
|
||||||
|
backend(backend) {
|
||||||
|
size_t out_h = (in_shape[0] + 2 * padding_shape[0] - pooling_shape[0]) /
|
||||||
|
stride_shape[0] +
|
||||||
|
1;
|
||||||
|
size_t out_w = (in_shape[1] + 2 * padding_shape[1] - pooling_shape[1]) /
|
||||||
|
stride_shape[1] +
|
||||||
|
1;
|
||||||
|
|
||||||
|
out_shape.resize(3);
|
||||||
|
out_shape[0] = out_h;
|
||||||
|
out_shape[1] = out_w;
|
||||||
|
out_shape[2] = in_shape[2];
|
||||||
|
|
||||||
|
output = CUDANet::Tensor(
|
||||||
|
Shape{out_shape[0] * out_shape[1] * out_shape[3]},
|
||||||
|
CUDANet::DType::FLOAT32, backend
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
MaxPool2d::~MaxPool2d() {}
|
||||||
|
|
||||||
|
CUDANet::Tensor& MaxPool2d::forward(CUDANet::Tensor& input) {
|
||||||
|
output.zero();
|
||||||
|
backend->maxPool2d(
|
||||||
|
input, output, in_shape, pooling_shape, stride_shape, padding_shape,
|
||||||
|
out_shape
|
||||||
|
);
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
|
CUDANet::Shape MaxPool2d::input_shape() {
|
||||||
|
return in_shape;
|
||||||
|
}
|
||||||
|
|
||||||
|
CUDANet::Shape MaxPool2d::output_shape() {
|
||||||
|
return out_shape;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t MaxPool2d::input_size() {
|
||||||
|
return sizeof(float) * in_shape[0] * in_shape[1] * in_shape[2];
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t MaxPool2d::output_size() {
|
||||||
|
return sizeof(float) * out_shape[0] * out_shape[1] * out_shape[2];
|
||||||
|
}
|
||||||
|
|
||||||
|
void MaxPool2d::set_weights(void* input) {}
|
||||||
|
|
||||||
|
CUDANet::Tensor& MaxPool2d::get_weights() {}
|
||||||
|
|
||||||
|
void MaxPool2d::set_biases(void* input) {}
|
||||||
|
|
||||||
|
CUDANet::Tensor& MaxPool2d::get_biases() {}
|
||||||
@@ -1,67 +0,0 @@
|
|||||||
#include "max_pooling.hpp"
|
|
||||||
#include <stdexcept>
|
|
||||||
|
|
||||||
using namespace CUDANet::Layers;
|
|
||||||
|
|
||||||
MaxPooling2d::MaxPooling2d(
|
|
||||||
shape2d inputSize,
|
|
||||||
int nChannels,
|
|
||||||
shape2d poolingSize,
|
|
||||||
shape2d stride,
|
|
||||||
shape2d padding,
|
|
||||||
ActivationType activationType
|
|
||||||
)
|
|
||||||
: inputSize(inputSize),
|
|
||||||
nChannels(nChannels),
|
|
||||||
poolingSize(poolingSize),
|
|
||||||
stride(stride),
|
|
||||||
padding(padding) {
|
|
||||||
outputSize = {
|
|
||||||
(inputSize.first + 2 * padding.first - poolingSize.first) /
|
|
||||||
stride.first +
|
|
||||||
1,
|
|
||||||
(inputSize.second + 2 * padding.second - poolingSize.second) /
|
|
||||||
stride.second +
|
|
||||||
1
|
|
||||||
};
|
|
||||||
|
|
||||||
activation = new Activation(
|
|
||||||
activationType, outputSize.first * outputSize.second * nChannels
|
|
||||||
);
|
|
||||||
|
|
||||||
#ifdef USE_CUDA
|
|
||||||
initCUDA();
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
MaxPooling2d::~MaxPooling2d() {
|
|
||||||
#ifdef USE_CUDA
|
|
||||||
delCUDA();
|
|
||||||
#endif
|
|
||||||
delete activation;
|
|
||||||
}
|
|
||||||
|
|
||||||
float* MaxPooling2d::forwardCPU(const float* input) {
|
|
||||||
throw std::logic_error("Not implemented");
|
|
||||||
}
|
|
||||||
|
|
||||||
float* MaxPooling2d::forward(const float* input) {
|
|
||||||
#ifdef USE_CUDA
|
|
||||||
return forwardCUDA(input);
|
|
||||||
#else
|
|
||||||
return forwardCPU(input);
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
int MaxPooling2d::get_output_size() {
|
|
||||||
return outputSize.first * outputSize.second * nChannels;
|
|
||||||
}
|
|
||||||
|
|
||||||
int MaxPooling2d::getInputSize() {
|
|
||||||
return inputSize.first * inputSize.second * nChannels;
|
|
||||||
}
|
|
||||||
|
|
||||||
shape2d MaxPooling2d::getOutputDims() {
|
|
||||||
return outputSize;
|
|
||||||
}
|
|
||||||
@@ -1,34 +0,0 @@
|
|||||||
#include "output.hpp"
|
|
||||||
#include <stdexcept>
|
|
||||||
|
|
||||||
using namespace CUDANet::Layers;
|
|
||||||
|
|
||||||
|
|
||||||
Output::Output(int inputSize) : inputSize(inputSize) {
|
|
||||||
h_output = (float*) malloc(sizeof(float) * inputSize);
|
|
||||||
}
|
|
||||||
|
|
||||||
Output::~Output() {
|
|
||||||
free(h_output);
|
|
||||||
}
|
|
||||||
|
|
||||||
float* Output::forwardCPU(const float* input) {
|
|
||||||
throw std::logic_error("Not implemented");
|
|
||||||
}
|
|
||||||
|
|
||||||
float* Output::forward(const float* input) {
|
|
||||||
#ifdef USE_CUDA
|
|
||||||
return forwardCUDA(input);
|
|
||||||
#else
|
|
||||||
return forwardCPU(input);
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
int Output::get_output_size() {
|
|
||||||
return inputSize;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
int Output::getInputSize() {
|
|
||||||
return inputSize;
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user