mirror of https://github.com/lordmathis/CUDANet.git
synced 2025-12-22 14:24:22 +00:00

Compare commits
3 Commits: 10c84d75fc ... e4d05931d4

| Author | SHA1 | Date |
|---|---|---|
| | e4d05931d4 | |
| | 7896ff0e24 | |
| | dfdfa19022 | |
@@ -40,6 +40,28 @@ class Backend {
        const size_t input_size,
        const size_t output_size
    ) = 0;

    virtual CUDANet::Tensor& conv2d(
        const CUDANet::Tensor& weights,
        const CUDANet::Tensor& biases,
        const CUDANet::Tensor& input,
        CUDANet::Tensor& output,
        const CUDANet::Shape in_shape,
        const CUDANet::Shape padding_shape,
        const CUDANet::Shape kernel_shape,
        const CUDANet::Shape stride_shape,
        const CUDANet::Shape out_shape
    ) = 0;

    virtual CUDANet::Tensor& maxPool2d(
        const CUDANet::Tensor& input,
        CUDANet::Tensor& output,
        CUDANet::Shape input_shape,
        CUDANet::Shape pool_shape,
        CUDANet::Shape stride_shape,
        CUDANet::Shape padding_shape,
        CUDANet::Shape output_shape
    ) = 0;
};

} // namespace CUDANet
@@ -36,6 +36,28 @@ class CUDA : public Backend {
        const size_t input_size,
        const size_t output_size
    ) override;

    CUDANet::Tensor& conv2d(
        const CUDANet::Tensor& weights,
        const CUDANet::Tensor& biases,
        const CUDANet::Tensor& input,
        CUDANet::Tensor& output,
        const CUDANet::Shape in_shape,
        const CUDANet::Shape padding_shape,
        const CUDANet::Shape kernel_shape,
        const CUDANet::Shape stride_shape,
        const CUDANet::Shape out_shape
    ) override;

    CUDANet::Tensor& maxPool2d(
        const CUDANet::Tensor& input,
        CUDANet::Tensor& output,
        CUDANet::Shape input_shape,
        CUDANet::Shape pool_shape,
        CUDANet::Shape stride_shape,
        CUDANet::Shape padding_shape,
        CUDANet::Shape output_shape
    ) override;
};

} // namespace CUDANet::Backend
@@ -1,39 +1,20 @@
#ifndef CUDANET_CONVOLUTION_H
#define CUDANET_CONVOLUTION_H
#pragma once

#include <cuda_runtime.h>
#include "layer.hpp"

namespace CUDANet::Kernels {

/**
 * @brief Convolution kernel
 *
 * @param d_input Device pointer to the input matrix
 * @param d_kernel Device pointer to the convolution kernel
 * @param d_bias Device pointer to the bias
 * @param d_output Device pointer to the output matrix
 * @param inputSize Width and height of the input matrix
 * @param nChannels Number of channels in the input matrix
 * @param kernelSize Width and height of the convolution kernel
 * @param stride Convolution stride
 * @param nFilters Number of output filters
 * @param outputSize Width and height of the output matrix
 */
__global__ void convolution(
    const float* __restrict__ d_input,
    const float* __restrict__ d_kernel,
    const float* __restrict__ d_bias,
    float* __restrict__ d_output,
    const shape2d inputSize,
    const int nChannels,
    const shape2d paddingSize,
    const shape2d kernelSize,
    const shape2d stride,
    const int nFilters,
    const shape2d outputSize
    const Shape input_shape,
    const Shape padding_shape,
    const Shape kernel_shape,
    const Shape stride_shape,
    const Shape output_shape
);

} // namespace CUDANet::Kernels

#endif // CUDANET_CONVOLUTION_H
@@ -1,33 +1,28 @@
#ifndef CUDANET_POOLING_H
#define CUDANET_POOLING_H
#pragma once

#include <cuda_runtime.h>
#include "layer.hpp"

namespace CUDANet::Kernels {

__global__ void max_pooling(
__global__ void max_pool(
    const float* __restrict__ d_input,
    float* __restrict__ d_output,
    const shape2d inputSize,
    const shape2d outputSize,
    const int nChannels,
    const shape2d poolingSize,
    const shape2d stride,
    const shape2d padding
    const Shape input_shape,
    const Shape output_shape,
    const Shape pool_shape,
    const Shape stride_shape,
    const Shape padding_shape
);

__global__ void avg_pooling(
__global__ void avg_pool(
    const float* __restrict__ d_input,
    float* __restrict__ d_output,
    const shape2d inputSize,
    const shape2d outputSize,
    const int nChannels,
    const shape2d poolingSize,
    const shape2d stride,
    const shape2d padding
    const Shape input_shape,
    const Shape output_shape,
    const Shape pool_shape,
    const Shape stride_shape,
    const Shape padding_shape
);

} // namespace CUDANet::Kernels

#endif // CUDANET_POOLING_H
@@ -20,7 +20,7 @@ class Layer {

    virtual ~Layer(){};

    virtual CUDANet::Tensor& forward(const CUDANet::Tensor &input) = 0;
    virtual CUDANet::Tensor& forward(CUDANet::Tensor &input) = 0;

    virtual CUDANet::Shape input_shape() = 0;
@@ -1,5 +1,4 @@
#ifndef CUDANET_CONV_LAYER_H
#define CUDANET_CONV_LAYER_H
#pragma once

#include <vector>
@@ -12,149 +11,52 @@ namespace CUDANet::Layers {
 * @brief 2D convolutional layer
 *
 */
class Conv2d : public WeightedLayer, public TwoDLayer {
class Conv2d : public Layer {
  public:
    /**
     * @brief Construct a new Conv 2d layer
     *
     * @param inputSize Width and height of the input matrix
     * @param inputChannels Number of channels in the input matrix
     * @param kernelSize Width and height of the convolution kernel
     * @param stride Convolution stride
     * @param numFilters Number of output filters
     * @param paddingSize Padding size
     * @param activationType Activation function type ('RELU', 'SIGMOID',
     * 'SOFTMAX' or 'NONE')
     */
    Conv2d(
        shape2d inputSize,
        int inputChannels,
        shape2d kernelSize,
        shape2d stride,
        int numFilters,
        shape2d paddingSize,
        ActivationType activationType
        CUDANet::Shape input_shape,
        CUDANet::Shape kernel_shape,
        CUDANet::Shape stride_shape,
        CUDANet::Shape padding_shape,
        CUDANet::Backend* backend
    );

    /**
     * @brief Destroy the Conv 2d object
     *
     */
    ~Conv2d();
    ~Conv2d() {};

    /**
     * @brief Forward pass of the convolutional layer
     *
     * @param d_input Device pointer to the input matrix
     * @return Device pointer to the output matrix
     */
    float* forward(const float* d_input);
    CUDANet::Tensor& forward(CUDANet::Tensor& input) override;

    /**
     * @brief Set the weights of the convolutional layer
     *
     * @param weights_input Pointer to the weights
     */
    void setWeights(const float* weights_input);
    CUDANet::Shape input_shape() override;

    /**
     * @brief Get the weights of the convolutional layer
     *
     * @return std::vector<float>
     */
    std::vector<float> getWeights();
    CUDANet::Shape output_shape() override;

    /**
     * @brief Set the biases of the convolutional layer
     *
     * @param biases_input Pointer to the biases
     */
    void setBiases(const float* biases_input);
    size_t input_size() override;

    /**
     * @brief Get the biases of the convolutional layer
     *
     * @return std::vector<float>
     */
    std::vector<float> getBiases();
    size_t output_size();

    /**
     * @brief Get output size
     *
     * @return int output size
     */
    int getOutputSize();
    void set_weights(void* input) override;

    /**
     * @brief Get input size
     *
     * @return int input size
     */
    int getInputSize();
    CUDANet::Tensor& get_weights() override;

    /**
     * @brief Get the padding size of the layer
     *
     * @return int
     */
    shape2d getPaddingSize() {
        return paddingSize;
    }
    void set_biases(void* input) override;

    shape2d getOutputDims();
    CUDANet::Tensor& get_biases() override;

    CUDANet::Shape get_padding_shape();

  private:
    // Inputs
    shape2d inputSize;
    int inputChannels;
    CUDANet::Backend* backend;

    // Outputs
    shape2d outputSize;
    CUDANet::Shape in_shape;
    CUDANet::Shape out_shape;

    // Kernel
    shape2d kernelSize;
    shape2d stride;
    shape2d paddingSize;
    int numFilters;
    CUDANet::Shape kernel_shape;
    CUDANet::Shape stride_shape;
    CUDANet::Shape padding_shape;

    // Kernels
    std::vector<float> weights;
    std::vector<float> biases;
    CUDANet::Tensor weights;
    CUDANet::Tensor biases;

    float* forwardCPU(const float* input);

    // Cuda
#ifdef USE_CUDA
    float* d_output;
    float* d_weights;
    float* d_biases;

    float* forwardCUDA(const float* d_input);
    void initCUDA();
    void delCUDA();

    /**
     * @brief Copy weights and biases to the device
     *
     */
    void toCuda();
#endif

    Activation* activation;

    /**
     * @brief Initialize weights of the convolutional layer with zeros
     *
     */
    void initializeWeights();

    /**
     * @brief Initialize biases of the convolutional layer with zeros
     *
     */
    void initializeBiases();
    CUDANet::Tensor output;
};

} // namespace CUDANet::Layers

#endif // CUDANET_CONV_LAYER_H
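A minimal usage sketch of the refactored Conv2d interface above (not part of the commit; the brace-constructed Shape values, the default-constructed CUDA backend, and the host_weights, host_biases, and input names are assumptions made here purely for illustration):

    // Hypothetical usage: construct the layer from shapes plus a backend pointer.
    CUDANet::Backend::CUDA backend;
    CUDANet::Layers::Conv2d conv(
        CUDANet::Shape{64, 64, 3},   // input_shape:  64x64 image, 3 channels
        CUDANet::Shape{3, 3, 16},    // kernel_shape: 3x3 window, 16 filters
        CUDANet::Shape{1, 1},        // stride_shape
        CUDANet::Shape{1, 1},        // padding_shape
        &backend
    );
    conv.set_weights(host_weights);  // float* passed through the void* setter
    conv.set_biases(host_biases);
    CUDANet::Tensor& out = conv.forward(input);  // input is a CUDANet::Tensor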
@@ -14,11 +14,11 @@ namespace CUDANet::Layers {
class Dense : public Layer {
  public:

    Dense(CUDANet::Backend *backend, CUDANet::Shape input_shape, CUDANet::Shape output_shape);
    Dense(CUDANet::Shape input_shape, CUDANet::Shape output_shape, CUDANet::Backend *backend);

    ~Dense();

    CUDANet::Tensor& forward(const CUDANet::Tensor &input) override;
    CUDANet::Tensor& forward(CUDANet::Tensor &input) override;

    CUDANet::Shape input_shape() override;
@@ -1,66 +0,0 @@
#ifndef CUDANET_INPUT_LAYER_H
#define CUDANET_INPUT_LAYER_H

#include "layer.hpp"

namespace CUDANet::Layers {

/**
 * @brief Input layer, just copies the input to the device
 *
 */
class Input : public Layer {
  public:
    /**
     * @brief Create a new Input layer
     *
     * @param inputSize Size of the input vector
     */
    explicit Input(int inputSize);

    /**
     * @brief Destroy the Input layer
     *
     */
    ~Input();

    /**
     * @brief Forward pass of the input layer. Just copies the input to the
     * device
     *
     * @param input Host pointer to the input vector
     * @return Device pointer to the output vector
     */
    float* forward(const float* input);

    /**
     * @brief Get output size
     *
     * @return int output size
     */
    int get_output_size();

    /**
     * @brief Get input size
     *
     * @return int input size
     */
    int getInputSize();

  private:
    int inputSize;

    float* forwardCPU(const float* input);

#ifdef USE_CUDA
    float* d_output;

    float* forwardCUDA(const float* input);
    void initCUDA();
    void delCUDA();
#endif
};

} // namespace CUDANet::Layers

#endif // CUDANET_INPUT_LAYER_H
include/layers/max_pool.hpp (new file, 51 lines)
@@ -0,0 +1,51 @@
#pragma once

#include "layer.hpp"

namespace CUDANet::Layers {

class MaxPool2d : public Layer {
  public:
    MaxPool2d(
        CUDANet::Shape input_shape,
        CUDANet::Shape pooling_shape,
        CUDANet::Shape stride_shape,
        CUDANet::Shape padding_shape,
        CUDANet::Backend* backend
    );
    ~MaxPool2d();

    CUDANet::Tensor& forward(CUDANet::Tensor &input) override;

    CUDANet::Shape input_shape() override;

    CUDANet::Shape output_shape() override;

    size_t input_size() override;

    size_t output_size() override;

    void set_weights(void *input) override;

    CUDANet::Tensor& get_weights() override;

    void set_biases(void *input) override;

    CUDANet::Tensor& get_biases() override;

  private:
    CUDANet::Shape in_shape;

    CUDANet::Shape pooling_shape;
    CUDANet::Shape stride_shape;
    CUDANet::Shape padding_shape;

    CUDANet::Shape out_shape;
    CUDANet::Tensor output;

    CUDANet::Backend *backend;
};

} // namespace CUDANet::Layers
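A matching usage sketch for the new MaxPool2d layer (illustrative only; the Shape literals and the backend follow the same assumptions as the Conv2d sketch above). The output spatial size is (in + 2*padding - pool) / stride + 1 per dimension, as computed in src/layers/max_pool.cpp further down:

    // Hypothetical usage of the new pooling layer.
    CUDANet::Layers::MaxPool2d pool(
        CUDANet::Shape{32, 32, 8},   // input_shape: 32x32, 8 channels
        CUDANet::Shape{2, 2},        // pooling_shape
        CUDANet::Shape{2, 2},        // stride_shape
        CUDANet::Shape{0, 0},        // padding_shape
        &backend
    );
    // out_shape = {(32 + 0 - 2) / 2 + 1, (32 + 0 - 2) / 2 + 1, 8} = {16, 16, 8}
    CUDANet::Tensor& pooled = pool.forward(out);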
@@ -1,63 +0,0 @@
#ifndef CUDANET_MAX_POOLING_H
#define CUDANET_MAX_POOLING_H

#include "activation.hpp"
#include "layer.hpp"

namespace CUDANet::Layers {

class MaxPooling2d : public Layer, public TwoDLayer {
  public:
    MaxPooling2d(
        shape2d inputSize,
        int nChannels,
        shape2d poolingSize,
        shape2d stride,
        shape2d padding,
        ActivationType activationType
    );
    ~MaxPooling2d();

    float* forward(const float* input);

    /**
     * @brief Get output size
     *
     * @return int output size
     */
    int get_output_size();

    /**
     * @brief Get input size
     *
     * @return int input size
     */
    int getInputSize();

    shape2d getOutputDims();

  private:
    shape2d inputSize;
    int nChannels;
    shape2d poolingSize;
    shape2d stride;
    shape2d padding;

    shape2d outputSize;

    Activation* activation;

    float* forwardCPU(const float* input);

#ifdef USE_CUDA
    float* d_output;
    float* forwardCUDA(const float* d_input);

    void initCUDA();
    void delCUDA();
#endif
};

} // namespace CUDANet::Layers

#endif // CUDANET_MAX_POOLING_H
@@ -1,59 +0,0 @@
#ifndef CUDANET_OUTPUT_LAYER_H
#define CUDANET_OUTPUT_LAYER_H

#include "layer.hpp"

namespace CUDANet::Layers {

class Output : public Layer {
  public:
    /**
     * @brief Create a new Output layer
     *
     * @param inputSize Size of the input vector
     */
    explicit Output(int inputSize);

    /**
     * @brief Destroy the Output layer
     *
     */
    ~Output();

    /**
     * @brief Forward pass of the output layer. Just copies the input from
     * device to host
     *
     * @param input Device pointer to the input vector
     * @return Host pointer to the output vector
     */
    float* forward(const float* input);

    /**
     * @brief Get output size
     *
     * @return int output size
     */
    int get_output_size();

    /**
     * @brief Get input size
     *
     * @return int input size
     */
    int getInputSize();

  private:
    int inputSize;
    float* h_output;

    float* forwardCPU(const float* input);

#ifdef USE_CUDA
    float* forwardCUDA(const float* input);
#endif
};

} // namespace CUDANet::Layers

#endif // CUDANET_OUTPUT_LAYER_H
@@ -9,52 +9,50 @@ __global__ void Kernels::convolution(
    const float* __restrict__ d_kernel,
    const float* __restrict__ d_bias,
    float* __restrict__ d_output,
    const shape2d inputSize,
    const int nChannels,
    const shape2d paddingSize,
    const shape2d kernelSize,
    const shape2d stride,
    const int nFilters,
    const shape2d outputSize
    const Shape input_shape,
    const Shape padding_shape,
    const Shape kernel_shape,
    const Shape stride_shape,
    const Shape output_shape
) {
    int j = blockDim.x * blockIdx.x + threadIdx.x;
    int i = blockDim.y * blockIdx.y + threadIdx.y;
    int f = blockDim.z * blockIdx.z + threadIdx.z;

    if (i >= outputSize.first || j >= outputSize.second || f >= nFilters) {
    if (i >= output_shape[0] || j >= output_shape[1] || f >= output_shape[2]) {
        return;
    }

    float sum = 0.0f;

    // Iterate over kernel and input matrix
    for (int c = 0; c < nChannels; c++) {
        for (int k = 0; k < kernelSize.first; k++) {
            for (int l = 0; l < kernelSize.second; l++) {
    for (int c = 0; c < input_shape[2]; c++) {
        for (int k = 0; k < kernel_shape[0]; k++) {
            for (int l = 0; l < kernel_shape[1]; l++) {
                // if i, j is in the padding region
                if (i * stride.first + k < paddingSize.first ||
                    i * stride.first + k >=
                        (inputSize.first + paddingSize.first) ||
                    j * stride.second + l < paddingSize.second ||
                    j * stride.second + l >=
                        (inputSize.second + paddingSize.second)) {
                if (i * stride_shape[0] + k < padding_shape[0] ||
                    i * stride_shape[0] + k >=
                        (input_shape[0] + padding_shape[0]) ||
                    j * stride_shape[1] + l < padding_shape[1] ||
                    j * stride_shape[1] + l >=
                        (input_shape[1] + padding_shape[1])) {
                    continue;
                }

                int kernelIndex =
                    f * kernelSize.first * kernelSize.second * nChannels +
                    c * kernelSize.first * kernelSize.second +
                    k * kernelSize.second + l;
                int inputIndex = c * inputSize.first * inputSize.second +
                                 (i * stride.first + k - paddingSize.first) *
                                     inputSize.second +
                                 (j * stride.second + l - paddingSize.second);
                    f * kernel_shape[0] * kernel_shape[1] * input_shape[2] +
                    c * kernel_shape[0] * kernel_shape[1] +
                    k * kernel_shape[1] + l;
                int inputIndex = c * input_shape[0] * input_shape[1] +
                                 (i * stride_shape[0] + k - padding_shape[0]) *
                                     input_shape[1] +
                                 (j * stride_shape[1] + l - padding_shape[1]);

                sum += d_kernel[kernelIndex] * d_input[inputIndex];
            }
        }
    }

    d_output[f * outputSize.first * outputSize.second + i * outputSize.second + j] =
    d_output[f * output_shape[0] * output_shape[1] + i * output_shape[1] + j] =
        sum + d_bias[f];
}
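For reference, the flattened index arithmetic used by the kernel above (identical in the old shape2d form and the new Shape form) can be summarized as follows; this merely restates the code, it adds nothing from the commit:

    // Weight layout: [filter f][channel c][kernel row k][kernel col l]
    //   kernelIndex = f*KH*KW*C + c*KH*KW + k*KW + l
    // Input layout:  [channel c][row][col], rows/cols offset by stride and padding
    //   inputIndex  = c*H*W + (i*stride_h + k - pad_h)*W + (j*stride_w + l - pad_w)
    // where C = input_shape[2], KH/KW = kernel_shape[0]/[1], H/W = input_shape[0]/[1].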
@@ -4,35 +4,34 @@

using namespace CUDANet;

__global__ void Kernels::max_pooling(
__global__ void Kernels::max_pool(
    const float* __restrict__ d_input,
    float* __restrict__ d_output,
    const shape2d inputSize,
    const shape2d outputSize,
    const int nChannels,
    const shape2d poolingSize,
    const shape2d stride,
    const shape2d padding
    const Shape input_shape,
    const Shape output_shape,
    const Shape pool_shape,
    const Shape stride_shape,
    const Shape padding_shape
) {
    int j = blockDim.x * blockIdx.x + threadIdx.x;
    int i = blockDim.y * blockIdx.y + threadIdx.y;
    int c = blockDim.z * blockIdx.z + threadIdx.z;

    if (i >= outputSize.first || j >= outputSize.second || c >= nChannels) {
    if (i >= output_shape[0] || j >= output_shape[1] || c >= output_shape[2]) {
        return;
    }

    float max = 0.0f;

    for (int k = 0; k < poolingSize.first; k++) {
        for (int l = 0; l < poolingSize.second; l++) {
            int inputRow = i * stride.first + k - padding.first;
            int inputCol = j * stride.second + l - padding.second;
    for (int k = 0; k < pool_shape[0]; k++) {
        for (int l = 0; l < pool_shape[1]; l++) {
            int inputRow = i * stride_shape[0] + k - padding_shape[0];
            int inputCol = j * stride_shape[1] + l - padding_shape[1];

            if (inputRow >= 0 && inputRow < inputSize.first && inputCol >= 0 &&
                inputCol < inputSize.second) {
                int inputIndex = c * inputSize.first * inputSize.second +
                                 inputRow * inputSize.second + inputCol;
            if (inputRow >= 0 && inputRow < input_shape[0] && inputCol >= 0 &&
                inputCol < input_shape[1]) {
                int inputIndex = c * input_shape[0] * input_shape[1] +
                                 inputRow * input_shape[1] + inputCol;
                if (d_input[inputIndex] > max) {
                    max = d_input[inputIndex];
                }
@@ -41,45 +40,44 @@ __global__ void Kernels::max_pooling(
        }
    }

    d_output
        [c * outputSize.first * outputSize.second + i * outputSize.second + j] =
        [c * output_shape[0] * output_shape[1] + i * output_shape[1] + j] =
            max;
}

__global__ void Kernels::avg_pooling(
__global__ void Kernels::avg_pool(
    const float* __restrict__ d_input,
    float* __restrict__ d_output,
    const shape2d inputSize,
    const shape2d outputSize,
    const int nChannels,
    const shape2d poolingSize,
    const shape2d stride,
    const shape2d padding
    const Shape input_shape,
    const Shape output_shape,
    const Shape pool_shape,
    const Shape stride_shape,
    const Shape padding_shape
) {
    int j = blockDim.x * blockIdx.x + threadIdx.x;
    int i = blockDim.y * blockIdx.y + threadIdx.y;
    int c = blockDim.z * blockIdx.z + threadIdx.z;

    if (i >= outputSize.first || j >= outputSize.second || c >= nChannels) {
    if (i >= output_shape[0] || j >= output_shape[1] || c >= output_shape[2]) {
        return;
    }

    float sum = 0.0f;

    for (int k = 0; k < poolingSize.first; k++) {
        for (int l = 0; l < poolingSize.second; l++) {
            int inputRow = i * stride.first + k - padding.first;
            int inputCol = j * stride.second + l - padding.second;
    for (int k = 0; k < pool_shape[0]; k++) {
        for (int l = 0; l < pool_shape[1]; l++) {
            int inputRow = i * stride_shape[0] + k - padding_shape[0];
            int inputCol = j * stride_shape[1] + l - padding_shape[1];

            if (inputRow >= 0 && inputRow < inputSize.first && inputCol >= 0 &&
                inputCol < inputSize.second) {
                int inputIndex = c * inputSize.first * inputSize.second +
                                 inputRow * inputSize.second + inputCol;
            if (inputRow >= 0 && inputRow < input_shape[0] && inputCol >= 0 &&
                inputCol < input_shape[1]) {
                int inputIndex = c * input_shape[0] * input_shape[1] +
                                 inputRow * input_shape[1] + inputCol;
                sum += d_input[inputIndex];
            }
        }
    }

    d_output
        [c * outputSize.first * outputSize.second + i * outputSize.second + j] =
            sum / (poolingSize.first * poolingSize.second);
        [c * output_shape[0] * output_shape[1] + i * output_shape[1] + j] =
            sum / (pool_shape[0] * pool_shape[1]);
}
@@ -1,6 +1,8 @@
#include "backend/cuda.cuh"
#include "kernels/activation_functions.cuh"
#include "kernels/convolution.cuh"
#include "kernels/matmul.cuh"
#include "kernels/pooling.cuh"
#include "utils/cuda_helper.cuh"

using namespace CUDANet::Backend;
@@ -57,7 +59,7 @@ CUDANet::Tensor& CUDA::dense(
    const CUDANet::Tensor& weights,
    const CUDANet::Tensor& biases,
    const CUDANet::Tensor& input,
    CUDANet::Tensor& output,
    CUDANet::Tensor& output,
    const size_t input_size,
    const size_t output_size
) {
@@ -78,5 +80,60 @@ CUDANet::Tensor& CUDA::dense(
    CUDA_CHECK(cudaGetLastError());
    CUDA_CHECK(cudaDeviceSynchronize());

    return output;
}

CUDANet::Tensor& CUDA::conv2d(
    const CUDANet::Tensor& weights,
    const CUDANet::Tensor& biases,
    const CUDANet::Tensor& input,
    CUDANet::Tensor& output,
    const CUDANet::Shape in_shape,
    const CUDANet::Shape padding_shape,
    const CUDANet::Shape kernel_shape,
    const CUDANet::Shape stride_shape,
    const CUDANet::Shape out_shape
) {
    dim3 block(8, 8, 8);
    dim3 grid(
        (out_shape[0] + block.x - 1) / block.x,
        (out_shape[1] + block.y - 1) / block.y,
        (out_shape[2] + block.z - 1) / block.z
    );

    Kernels::convolution<<<grid, block>>>(
        input.data<float>(), weights.data<float>(), biases.data<float>(),
        output.data<float>(), in_shape, padding_shape, kernel_shape,
        stride_shape, out_shape
    );
    CUDA_CHECK(cudaGetLastError());
    CUDA_CHECK(cudaDeviceSynchronize());

    return output;
}

CUDANet::Tensor& CUDA::maxPool2d(
    const CUDANet::Tensor& input,
    CUDANet::Tensor& output,
    CUDANet::Shape input_shape,
    CUDANet::Shape pool_shape,
    CUDANet::Shape stride_shape,
    CUDANet::Shape padding_shape,
    CUDANet::Shape output_shape
) {
    dim3 block(8, 8, 8);
    dim3 grid(
        (output_shape[0] + block.x - 1) / block.x,
        (output_shape[1] + block.y - 1) / block.y,
        (output_shape[2] + block.z - 1) / block.z
    );

    Kernels::max_pool<<<grid, block>>>(
        input.data<float>(), output.data<float>(), input_shape, output_shape, pool_shape,
        stride_shape, padding_shape
    );
    CUDA_CHECK(cudaGetLastError());
    CUDA_CHECK(cudaDeviceSynchronize());

    return output;
}
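Both launches above use a fixed 8x8x8 thread block and ceil-divide each output dimension to size the grid, giving one thread per output element; a small sketch with hypothetical values (a 16x16 output with 8 channels), assuming Shape is brace-initializable:

    dim3 block(8, 8, 8);
    CUDANet::Shape out = {16, 16, 8};
    dim3 grid(
        (out[0] + block.x - 1) / block.x,  // (16 + 7) / 8 = 2
        (out[1] + block.y - 1) / block.y,  // 2
        (out[2] + block.z - 1) / block.z   // (8 + 7) / 8 = 1
    );
    // Threads outside the output shape return early at the kernel's bounds check.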
@@ -1,73 +0,0 @@
#include <vector>

#include "activation.hpp"
#include "conv2d.hpp"
#include "convolution.cuh"
#include "cuda_helper.cuh"
#include "layer.hpp"
#include "matmul.cuh"
#include "vector.cuh"

using namespace CUDANet::Layers;

void Conv2d::initCUDA() {
    d_output = nullptr;
    CUDA_CHECK(cudaMalloc(
        (void**)&d_output,
        sizeof(float) * outputSize.first * outputSize.second * numFilters
    ));

    d_weights = nullptr;
    CUDA_CHECK(cudaMalloc(
        (void**)&d_weights, sizeof(float) * kernelSize.first *
                                kernelSize.second * inputChannels * numFilters
    ));

    d_biases = nullptr;
    CUDA_CHECK(cudaMalloc((void**)&d_biases, sizeof(float) * numFilters));
}

void Conv2d::delCUDA() {
    cudaFree(d_output);
    cudaFree(d_weights);
    cudaFree(d_biases);
}

void Conv2d::toCuda() {
    CUDA_CHECK(cudaMemcpy(
        d_weights, weights.data(),
        sizeof(float) * kernelSize.first * kernelSize.second * inputChannels *
            numFilters,
        cudaMemcpyHostToDevice
    ));

    CUDA_CHECK(cudaMemcpy(
        d_biases, biases.data(), sizeof(float) * numFilters,
        cudaMemcpyHostToDevice
    ));
}

float* Conv2d::forwardCUDA(const float* d_input) {
    // Convolve
    dim3 block(8, 8, 8);
    dim3 grid(
        (outputSize.first + block.x - 1) / block.x,
        (outputSize.second + block.y - 1) / block.y,
        (numFilters + block.z - 1) / block.z
    );

    CUDANet::Utils::clear(d_output, outputSize.first * outputSize.second * numFilters);

    Kernels::convolution<<<grid, block>>>(
        d_input, d_weights, d_biases, d_output, inputSize, inputChannels,
        paddingSize, kernelSize, stride, numFilters, outputSize
    );
    CUDA_CHECK(cudaGetLastError());

    // Apply activation
    activation->activate(d_output);

    CUDA_CHECK(cudaDeviceSynchronize());

    return d_output;
}
@@ -1,22 +0,0 @@
#include "cuda_helper.cuh"
#include "input.hpp"

using namespace CUDANet::Layers;

void Input::initCUDA() {
    d_output = nullptr;
    CUDA_CHECK(cudaMalloc((void**)&d_output, sizeof(float) * inputSize));
}

void Input::delCUDA() {
    cudaFree(d_output);
}

float* Input::forwardCUDA(const float* input) {
    CUDA_CHECK(cudaMemcpy(
        d_output, input, sizeof(float) * inputSize, cudaMemcpyHostToDevice
    ));
    CUDA_CHECK(cudaDeviceSynchronize());

    return d_output;
}
@@ -1,38 +0,0 @@
#include "cuda_helper.cuh"
#include "max_pooling.hpp"
#include "pooling.cuh"

using namespace CUDANet::Layers;

void MaxPooling2d::initCUDA() {
    d_output = nullptr;
    CUDA_CHECK(cudaMalloc(
        (void**)&d_output,
        sizeof(float) * outputSize.first * outputSize.second * nChannels
    ));
}

void MaxPooling2d::delCUDA() {
    cudaFree(d_output);
}

float* MaxPooling2d::forwardCUDA(const float* d_input) {
    dim3 block(8, 8, 8);
    dim3 grid(
        (outputSize.first + block.x - 1) / block.x,
        (outputSize.second + block.y - 1) / block.y,
        (nChannels + block.z - 1) / block.z
    );

    Kernels::max_pooling<<<grid, block>>>(
        d_input, d_output, inputSize, outputSize, nChannels, poolingSize,
        stride, padding
    );
    CUDA_CHECK(cudaGetLastError());

    activation->activate(d_output);
    CUDA_CHECK(cudaDeviceSynchronize());

    return d_output;
}
@@ -1,14 +0,0 @@
#include "output.hpp"

#include "cuda_helper.cuh"

using namespace CUDANet::Layers;

float* Output::forwardCUDA(const float* input) {
    CUDA_CHECK(cudaMemcpy(
        h_output, input, sizeof(float) * inputSize, cudaMemcpyDeviceToHost
    ));
    CUDA_CHECK(cudaDeviceSynchronize());

    return h_output;
}
@@ -1,111 +1,136 @@
#include <stdexcept>
#include <vector>

#include "activation.hpp"
#include "conv2d.hpp"

#include <format>
#include <stdexcept>

#include "layer.hpp"
#include "tensor.hpp"

using namespace CUDANet::Layers;

Conv2d::Conv2d(
    shape2d inputSize,
    int inputChannels,
    shape2d kernelSize,
    shape2d stride,
    int numFilters,
    shape2d paddingSize,
    ActivationType activationType
    CUDANet::Shape input_shape,
    CUDANet::Shape kernel_shape,
    CUDANet::Shape stride_shape,
    CUDANet::Shape padding_shape,
    CUDANet::Backend* backend
)
    : inputSize(inputSize),
      inputChannels(inputChannels),
      kernelSize(kernelSize),
      stride(stride),
      numFilters(numFilters),
      paddingSize(paddingSize) {
    outputSize = {
        (inputSize.first - kernelSize.first + 2 * paddingSize.first) /
                stride.first +
            1,
        (inputSize.second - kernelSize.second + 2 * paddingSize.second) /
                stride.second +
            1
    };
    : in_shape(input_shape),
      kernel_shape(kernel_shape),
      stride_shape(stride_shape),
      padding_shape(padding_shape),
      backend(backend) {
    if (in_shape.size() != 3) {
        throw std::runtime_error(
            std::format(
                "Invalid input shape. Expected 3 dims, got {}", in_shape
            )
        );
    }

    activation = new Activation(
        activationType, outputSize.first * outputSize.second * numFilters
    if (kernel_shape.size() != 3) {
        throw std::runtime_error(
            std::format(
                "Invalid kernel shape. Expected 3 dims, got {}", kernel_shape
            )
        );
    }

    if (stride_shape.size() != 2) {
        throw std::runtime_error(
            std::format(
                "Invalid stride shape. Expected 2 dims, got {}", stride_shape
            )
        );
    }

    if (padding_shape.size() != 2) {
        throw std::runtime_error(
            std::format(
                "Invalid padding shape. Expected 2 dims, got {}", padding_shape
            )
        );
    }

    size_t out_h = (in_shape[0] - kernel_shape[0] + 2 * padding_shape[0]) /
                       stride_shape[0] +
                   1;
    size_t out_w = (in_shape[1] - kernel_shape[1] + 2 * padding_shape[1]) /
                       stride_shape[1] +
                   1;
    out_shape.resize(3);
    out_shape[0] = out_h;
    out_shape[1] = out_w;
    out_shape[2] = kernel_shape[2];
    output = CUDANet::Tensor(
        Shape{out_shape[0] * out_shape[1] * out_shape[2]},
        CUDANet::DType::FLOAT32, backend
    );

    weights.resize(
        kernelSize.first * kernelSize.second * inputChannels * numFilters
    weights = CUDANet::Tensor(
        Shape{
            kernel_shape[0] * kernel_shape[1] * kernel_shape[2] * in_shape[2]
        },
        CUDANet::DType::FLOAT32, backend
    );
    biases = CUDANet::Tensor(
        Shape{kernel_shape[2]}, CUDANet::DType::FLOAT32, backend
    );
    initializeWeights();

    biases.resize(numFilters);
    initializeBiases();

#ifdef USE_CUDA
    initCUDA();
    toCuda();
#endif
    weights.zero();
    biases.zero();
}

Conv2d::~Conv2d() {
#ifdef USE_CUDA
    delCUDA();
#endif
    delete activation;
Conv2d::~Conv2d() {}

CUDANet::Tensor& Conv2d::forward(CUDANet::Tensor& input) {
    output.zero();
    backend->conv2d(
        weights,
        biases,
        input,
        output,
        in_shape,
        padding_shape,
        kernel_shape,
        stride_shape,
        out_shape
    );
    return output;
}

void Conv2d::initializeWeights() {
    std::fill(weights.begin(), weights.end(), 0.0f);
CUDANet::Shape Conv2d::input_shape() {
    return in_shape;
}

void Conv2d::initializeBiases() {
    std::fill(biases.begin(), biases.end(), 0.0f);
CUDANet::Shape Conv2d::output_shape() {
    return out_shape;
}

void Conv2d::setWeights(const float* weights_input) {
    std::copy(weights_input, weights_input + weights.size(), weights.begin());
#ifdef USE_CUDA
    toCuda();
#endif
size_t Conv2d::input_size() {
    return sizeof(float) * in_shape[0] * in_shape[1] * in_shape[2];
}

std::vector<float> Conv2d::getWeights() {
size_t Conv2d::output_size() {
    return sizeof(float) * out_shape[0] * out_shape[1] * out_shape[2];
}

void Conv2d::set_weights(void* input) {
    weights.set_data<float>(static_cast<float*>(input));
}

CUDANet::Tensor& Conv2d::get_weights() {
    return weights;
}

void Conv2d::setBiases(const float* biases_input) {
    std::copy(biases_input, biases_input + biases.size(), biases.begin());
#ifdef USE_CUDA
    toCuda();
#endif
void Conv2d::set_biases(void* input) {
    biases.set_data<float>(static_cast<float*>(input));
}

std::vector<float> Conv2d::getBiases() {
CUDANet::Tensor& Conv2d::get_biases() {
    return biases;
}

float* Conv2d::forwardCPU(const float* input) {
    throw std::logic_error("Not implemented");
}

float* Conv2d::forward(const float* input) {
#ifdef USE_CUDA
    return forwardCUDA(input);
#else
    return forwardCPU(input);
#endif
}

int Conv2d::getOutputSize() {
    return outputSize.first * outputSize.second * numFilters;
}

int Conv2d::getInputSize() {
    return inputSize.first * inputSize.second * inputChannels;
}

shape2d Conv2d::getOutputDims() {
    return outputSize;
CUDANet::Shape Conv2d::get_padding_shape() {
    return padding_shape;
}
@@ -5,39 +5,35 @@

using namespace CUDANet::Layers;

Dense::Dense(CUDANet::Backend* backend, CUDANet::Shape in, CUDANet::Shape out)
Dense::Dense(CUDANet::Shape in, CUDANet::Shape out, CUDANet::Backend* backend)
    : backend(backend),
      in_shape(in),
      out_shape(out),
      weights(
          CUDANet::Tensor(Shape{in[0] * out[0]}, CUDANet::DType::FLOAT32, backend)
      ),
      biases(CUDANet::Tensor(Shape{out[0]}, CUDANet::DType::FLOAT32, backend)),
      output(CUDANet::Tensor(Shape{out[0]}, CUDANet::DType::FLOAT32, backend)) {
    // Allocate memory for weights and biases
      out_shape(out) {

    if (in.size() != 1) {
        throw std::runtime_error(
            std::format("Invalid shape. Expected [1], got {}", in)
            std::format("Invalid shape. Expected [1], got {}", in_shape)
        );
    }

    if (out.size() != 1) {
        throw std::runtime_error(
            std::format("Invalid shape. Expected [1], got {}", out)
            std::format("Invalid shape. Expected [1], got {}", out_shape)
        );
    }

    auto input_len = in[0];
    auto output_len = out[0];
    weights = CUDANet::Tensor(Shape{in[0] * out[0]}, CUDANet::DType::FLOAT32, backend);
    biases = CUDANet::Tensor(Shape{out[0]}, CUDANet::DType::FLOAT32, backend);
    output = CUDANet::Tensor(Shape{out[0]}, CUDANet::DType::FLOAT32, backend);

    weights.zero();
    biases.zero();
    output.zero();
}

Dense::~Dense() {}

CUDANet::Tensor& Dense::forward(const CUDANet::Tensor& input) {
CUDANet::Tensor& Dense::forward(CUDANet::Tensor& input) {
    backend->dense(weights, biases, input, output, in_shape[0], out_shape[0]);
    return output;
}
@@ -1,37 +0,0 @@
#include <stdexcept>

#include "input.hpp"

using namespace CUDANet::Layers;

Input::Input(int inputSize) : inputSize(inputSize) {
#ifdef USE_CUDA
    initCUDA();
#endif
}

Input::~Input() {
#ifdef USE_CUDA
    delCUDA();
#endif
}

float* Input::forwardCPU(const float* input) {
    throw std::logic_error("Not implemented");
}

float* Input::forward(const float* input) {
#ifdef USE_CUDA
    return forwardCUDA(input);
#else
    return forwardCPU(input);
#endif
}

int Input::get_output_size() {
    return inputSize;
}

int Input::getInputSize() {
    return inputSize;
}
src/layers/max_pool.cpp (new file, 70 lines)
@@ -0,0 +1,70 @@
#include "max_pool.hpp"

#include <stdexcept>

using namespace CUDANet::Layers;

MaxPool2d::MaxPool2d(
    CUDANet::Shape input_shape,
    CUDANet::Shape pooling_shape,
    CUDANet::Shape stride_shape,
    CUDANet::Shape padding_shape,
    CUDANet::Backend* backend
)
    : in_shape(input_shape),
      pooling_shape(pooling_shape),
      stride_shape(stride_shape),
      padding_shape(padding_shape),
      backend(backend) {
    size_t out_h = (in_shape[0] + 2 * padding_shape[0] - pooling_shape[0]) /
                       stride_shape[0] +
                   1;
    size_t out_w = (in_shape[1] + 2 * padding_shape[1] - pooling_shape[1]) /
                       stride_shape[1] +
                   1;

    out_shape.resize(3);
    out_shape[0] = out_h;
    out_shape[1] = out_w;
    out_shape[2] = in_shape[2];

    output = CUDANet::Tensor(
        Shape{out_shape[0] * out_shape[1] * out_shape[2]},
        CUDANet::DType::FLOAT32, backend
    );
}

MaxPool2d::~MaxPool2d() {}

CUDANet::Tensor& MaxPool2d::forward(CUDANet::Tensor& input) {
    output.zero();
    backend->maxPool2d(
        input, output, in_shape, pooling_shape, stride_shape, padding_shape,
        out_shape
    );
    return output;
}

CUDANet::Shape MaxPool2d::input_shape() {
    return in_shape;
}

CUDANet::Shape MaxPool2d::output_shape() {
    return out_shape;
}

size_t MaxPool2d::input_size() {
    return sizeof(float) * in_shape[0] * in_shape[1] * in_shape[2];
}

size_t MaxPool2d::output_size() {
    return sizeof(float) * out_shape[0] * out_shape[1] * out_shape[2];
}

void MaxPool2d::set_weights(void* input) {}

CUDANet::Tensor& MaxPool2d::get_weights() {}

void MaxPool2d::set_biases(void* input) {}

CUDANet::Tensor& MaxPool2d::get_biases() {}
@@ -1,67 +0,0 @@
#include "max_pooling.hpp"
#include <stdexcept>

using namespace CUDANet::Layers;

MaxPooling2d::MaxPooling2d(
    shape2d inputSize,
    int nChannels,
    shape2d poolingSize,
    shape2d stride,
    shape2d padding,
    ActivationType activationType
)
    : inputSize(inputSize),
      nChannels(nChannels),
      poolingSize(poolingSize),
      stride(stride),
      padding(padding) {
    outputSize = {
        (inputSize.first + 2 * padding.first - poolingSize.first) /
                stride.first +
            1,
        (inputSize.second + 2 * padding.second - poolingSize.second) /
                stride.second +
            1
    };

    activation = new Activation(
        activationType, outputSize.first * outputSize.second * nChannels
    );

#ifdef USE_CUDA
    initCUDA();
#endif
}

MaxPooling2d::~MaxPooling2d() {
#ifdef USE_CUDA
    delCUDA();
#endif
    delete activation;
}

float* MaxPooling2d::forwardCPU(const float* input) {
    throw std::logic_error("Not implemented");
}

float* MaxPooling2d::forward(const float* input) {
#ifdef USE_CUDA
    return forwardCUDA(input);
#else
    return forwardCPU(input);
#endif
}

int MaxPooling2d::get_output_size() {
    return outputSize.first * outputSize.second * nChannels;
}

int MaxPooling2d::getInputSize() {
    return inputSize.first * inputSize.second * nChannels;
}

shape2d MaxPooling2d::getOutputDims() {
    return outputSize;
}
@@ -1,34 +0,0 @@
#include "output.hpp"
#include <stdexcept>

using namespace CUDANet::Layers;

Output::Output(int inputSize) : inputSize(inputSize) {
    h_output = (float*) malloc(sizeof(float) * inputSize);
}

Output::~Output() {
    free(h_output);
}

float* Output::forwardCPU(const float* input) {
    throw std::logic_error("Not implemented");
}

float* Output::forward(const float* input) {
#ifdef USE_CUDA
    return forwardCUDA(input);
#else
    return forwardCPU(input);
#endif
}

int Output::get_output_size() {
    return inputSize;
}

int Output::getInputSize() {
    return inputSize;
}