Abstract activation and implement softmax
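Summary of the change: the per-layer activation switch statements are replaced by a dedicated Layers::Activation class, the kernel file src/kernels/activations.cu is renamed to src/kernels/activation_functions.cu, and softmax support is added as a three-kernel pipeline (softmax_exp, softmax_sum, softmax_div). The unfinished reduce_sum kernel in src/kernels/matmul.cu is removed in favor of softmax_sum, and Dense/Conv2d now take a Layers::ActivationType; Conv2d also swaps the order of its numFilters and padding parameters.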
CMakeLists.txt
@@ -9,12 +9,13 @@ include_directories(${CUDAToolkit_INCLUDE_DIRS})
 set(LIBRARY_SOURCES
     src/utils/cuda_helper.cu
-    src/kernels/activations.cu
+    src/kernels/activation_functions.cu
     src/kernels/convolution.cu
     src/kernels/matmul.cu
     src/layers/dense.cu
     src/layers/conv2d.cu
     src/layers/input.cu
+    src/layers/activation.cu
 )

 set(CMAKE_CUDA_ARCHITECTURES 75)
README.md
@@ -24,7 +24,7 @@ Convolutional Neural Network inference library running on CUDA.
 - [CUDA](https://developer.nvidia.com/cuda-downloads)
 - [Google Test](https://github.com/google/googletest) (for testing only)

-**build and test**
+**build**

 ```sh
 mkdir build
@@ -33,8 +33,9 @@ cmake -S ..
 make
 ```

-Run tests
+**build and run tests**

 ```sh
+make test_main
 ./test/test_main
 ```
include/kernels/activation_functions.cuh (new file, 74 lines)
@@ -0,0 +1,74 @@
+#ifndef CUDANET_ACTIVATION_FUNCTIONS_H
+#define CUDANET_ACTIVATION_FUNCTIONS_H
+
+namespace CUDANet::Kernels {
+
+/**
+ * @brief Sigmoid activation function kernel
+ *
+ * @param src Pointer to the source array
+ * @param dst Pointer to the destination array
+ * @param len Length of the arrays
+ */
+__global__ void sigmoid(
+    const float* __restrict__ src,
+    float* __restrict__ dst,
+    const unsigned int len
+);
+
+/**
+ * @brief Relu activation function kernel
+ *
+ * @param src Pointer to the source array
+ * @param dst Pointer to the destination array
+ * @param len Length of the arrays
+ */
+__global__ void relu(
+    const float* __restrict__ src,
+    float* __restrict__ dst,
+    const unsigned int len
+);
+
+/**
+ * @brief Softmax activation exponentiation kernel
+ *
+ * @param src Pointer to the source array
+ * @param dst Pointer to the destination array
+ * @param len Length of the arrays
+ */
+__global__ void softmax_exp(
+    const float* __restrict__ src,
+    float* __restrict__ dst,
+    const unsigned int len
+);
+
+/**
+ * @brief Softmax parallel sum reduction kernel
+ *
+ * @param d_vector Device pointer to vector
+ * @param d_output Device pointer to output vector
+ * @param w Length of the vector
+ */
+__global__ void softmax_sum(
+    const float* __restrict__ d_vector,
+    float* __restrict__ d_output,
+    const unsigned int w
+);
+
+/**
+ * @brief Softmax activation division kernel
+ *
+ * @param src Pointer to the source array
+ * @param dst Pointer to the destination array
+ * @param sum Pointer to the sum of the exponentiated source array
+ * @param len Length of the arrays
+ */
+__global__ void softmax_div(
+    const float* __restrict__ src,
+    float* __restrict__ dst,
+    const float* __restrict__ sum,
+    const unsigned int len
+);
+
+} // namespace CUDANet::Kernels
+
+#endif // CUDANET_ACTIVATION_FUNCTIONS_H
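For orientation, here is a minimal host-side driver for one of these kernels. It is a hypothetical sketch, not part of the commit: the hard-coded BLOCK_SIZE is an assumption standing in for the definition in cuda_helper.cuh, and the grid sizing mirrors what the new Layers::Activation class does below.

```cpp
// Hypothetical driver for the sigmoid kernel (sketch, not from the commit).
#include <cstdio>
#include <cuda_runtime.h>

#include "activation_functions.cuh"

#define BLOCK_SIZE 256  // assumption: stands in for the cuda_helper.cuh value

int main() {
    const unsigned int len = 5;
    float h_data[len] = {-2.0f, -1.0f, 0.0f, 1.0f, 2.0f};

    float* d_data;
    cudaMalloc((void**)&d_data, sizeof(float) * len);
    cudaMemcpy(d_data, h_data, sizeof(float) * len, cudaMemcpyHostToDevice);

    // One thread per element, rounded up to whole blocks.
    unsigned int gridSize = (len + BLOCK_SIZE - 1) / BLOCK_SIZE;
    CUDANet::Kernels::sigmoid<<<gridSize, BLOCK_SIZE>>>(d_data, d_data, len);
    cudaDeviceSynchronize();

    cudaMemcpy(h_data, d_data, sizeof(float) * len, cudaMemcpyDeviceToHost);
    for (unsigned int i = 0; i < len; ++i) {
        printf("%f\n", h_data[i]);  // ~0.12, 0.27, 0.50, 0.73, 0.88
    }

    cudaFree(d_data);
    return 0;
}
```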
include/kernels/activations.cuh (deleted, 28 lines)
@@ -1,28 +0,0 @@
-#ifndef CUDANET_ACTIVATIONS_H
-#define CUDANET_ACTIVATIONS_H
-
-namespace CUDANet::Kernels {
-
-/**
- * @brief Sigmoid activation function kernel
- *
- * @param src Pointer to the source array
- * @param dst Pointer to the destination array
- * @param len Length of the arrays
- */
-__global__ void
-sigmoid(const float* __restrict__ src, float* __restrict__ dst, int len);
-
-/**
- * @brief Relu activation function kernel
- *
- * @param src Pointer to the source array
- * @param dst Pointer to the destination array
- * @param len Length of the arrays
- */
-__global__ void
-relu(const float* __restrict__ src, float* __restrict__ dst, int len);
-
-} // namespace CUDANet::Kernels
-
-#endif // CUDANET_ACTIVATIONS_H
include/kernels/matmul.cuh
@@ -35,19 +35,6 @@ __global__ void vec_vec_add(
     const unsigned int w
 );

-/**
- * @brief
- *
- * @param d_vector Device pointer to vector
- * @param d_output Device pointer to output vector
- * @param w Length of the vector
- */
-__global__ void reduce_sum(
-    const float* __restrict__ d_vector,
-    float* __restrict__ d_output,
-    const unsigned int w
-);
-
 } // namespace CUDANet::Kernels

 #endif // CUDANET_MATMUL_H
include/layers/activation.cuh (new file, 55 lines)
@@ -0,0 +1,55 @@
+#ifndef CUDANET_ACTIVATION_H
+#define CUDANET_ACTIVATION_H
+
+namespace CUDANet::Layers {
+
+/**
+ * @brief Activation functions
+ *
+ * SIGMOID: Sigmoid
+ * RELU: Rectified Linear Unit
+ * SOFTMAX: Softmax
+ *
+ */
+enum ActivationType { SIGMOID, RELU, SOFTMAX, NONE };
+
+class Activation {
+  public:
+    Activation() = default;
+
+    /**
+     * @brief Construct a new Activation object
+     *
+     * @param activation Type of activation
+     * @param length Length of the input
+     */
+    Activation(ActivationType activation, const unsigned int length);
+
+    /**
+     * @brief Destroy the Activation object
+     *
+     */
+    ~Activation();
+
+    /**
+     * @brief Run the activation function on the input
+     *
+     * @param d_input Pointer to the input vector on the device
+     */
+    void activate(float* d_input);
+
+  private:
+    ActivationType activationType;
+    unsigned int length;
+    unsigned int gridSize;
+
+    float* d_softmax_sum;
+};
+
+} // namespace CUDANet::Layers
+
+#endif // CUDANET_ACTIVATION_H
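The intended call pattern, as a hedged sketch (classify and d_logits are invented names; the buffer is assumed to already live on the device):

```cpp
// Hypothetical usage of the new Activation layer (sketch, not from the commit).
#include "activation.cuh"

void classify(float* d_logits, unsigned int numClasses) {
    // For SOFTMAX the constructor also allocates the d_softmax_sum scratch
    // buffer used by the reduction; the destructor frees it.
    CUDANet::Layers::Activation softmax(CUDANet::Layers::SOFTMAX, numClasses);

    // In place: d_logits holds the normalized probabilities afterwards.
    softmax.activate(d_logits);
}
```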
include/layers/conv2d.cuh
@@ -4,7 +4,7 @@
 #include <string>
 #include <vector>

-#include "activations.cuh"
+#include "activation.cuh"
 #include "convolution.cuh"
 #include "ilayer.cuh"

@@ -23,18 +23,18 @@ class Conv2d : public ILayer {
      * @param inputChannels Number of channels in the input matrix
      * @param kernelSize Width and height of the convolution kernel
      * @param stride Convolution stride
-     * @param padding Padding type ('SAME' or 'VALID')
      * @param numFilters Number of output filters
-     * @param activation Activation function ('RELU', 'SIGMOID' or 'NONE')
+     * @param padding Padding type ('SAME' or 'VALID')
+     * @param activationType Activation function type ('RELU', 'SIGMOID', 'SOFTMAX' or 'NONE')
      */
     Conv2d(
         int inputSize,
         int inputChannels,
         int kernelSize,
         int stride,
-        Layers::Padding padding,
         int numFilters,
+        Layers::Padding padding,
-        Layers::Activation activation
+        Layers::ActivationType activationType
     );

     /**
@@ -67,17 +67,21 @@ class Conv2d : public ILayer {

     /**
      * @brief Get the output width (/ height) of the layer
      *
      * @return int
      */
-    int getOutputSize() { return outputSize; }
+    int getOutputSize() {
+        return outputSize;
+    }

     /**
      * @brief Get the padding size of the layer
      *
      * @return int
      */
-    int getPaddingSize() { return paddingSize; }
+    int getPaddingSize() {
+        return paddingSize;
+    }

   private:
     // Inputs
include/layers/dense.cuh
@@ -20,9 +20,9 @@ class Dense : public ILayer {
      *
      * @param inputSize Size of the input vector
      * @param outputSize Size of the output vector
-     * @param activation Activation function ('RELU', 'SIGMOID' or 'NONE')
+     * @param activationType Activation function type ('RELU', 'SIGMOID', 'SOFTMAX' or 'NONE')
      */
-    Dense(int inputSize, int outputSize, Layers::Activation activation);
+    Dense(int inputSize, int outputSize, Layers::ActivationType activationType);

     /**
      * @brief Destroy the Dense layer
include/layers/ilayer.cuh
@@ -6,15 +6,6 @@
 namespace CUDANet::Layers {

-/**
- * @brief Activation functions
- *
- * SIGMOID: Sigmoid
- * RELU: Rectified Linear Unit
- *
- */
-enum Activation { SIGMOID, RELU, NONE };
-
 /**
  * @brief Padding types
  *
@@ -85,7 +76,6 @@ class ILayer {
     std::vector<float> weights;
     std::vector<float> biases;

-    Layers::Activation activation;
 };

 } // namespace CUDANet::Layers
src/kernels/activation_functions.cu (new file, 79 lines)
@@ -0,0 +1,79 @@
+#include <functional>
+
+#include "activation_functions.cuh"
+#include "cuda_helper.cuh"
+
+__global__ void CUDANet::Kernels::sigmoid(
+    const float* __restrict__ src,
+    float* __restrict__ dst,
+    const unsigned int len
+) {
+    int stride = gridDim.x * blockDim.x;
+    int tid = blockDim.x * blockIdx.x + threadIdx.x;
+
+    for (int i = tid; i < len; i += stride) {
+        dst[i] = 1.0 / (1.0 + exp(-src[i]));
+    }
+}
+
+__global__ void CUDANet::Kernels::relu(
+    const float* __restrict__ src,
+    float* __restrict__ dst,
+    const unsigned int len
+) {
+    int stride = gridDim.x * blockDim.x;
+    int tid = blockDim.x * blockIdx.x + threadIdx.x;
+
+    for (int i = tid; i < len; i += stride) {
+        dst[i] = src[i] < 0.0 ? 0.0 : src[i];
+    }
+}
+
+__global__ void CUDANet::Kernels::softmax_exp(
+    const float* __restrict__ src,
+    float* __restrict__ dst,
+    const unsigned int len
+) {
+    int stride = gridDim.x * blockDim.x;
+    int tid = blockDim.x * blockIdx.x + threadIdx.x;
+
+    for (int i = tid; i < len; i += stride) {
+        dst[i] = exp(src[i]);
+    }
+}
+
+__global__ void CUDANet::Kernels::softmax_sum(
+    const float* __restrict__ d_vector,
+    float* __restrict__ d_output,
+    const unsigned int w
+) {
+    __shared__ float partial_sum[BLOCK_SIZE];
+    int i = blockIdx.x * blockDim.x * 2 + threadIdx.x;
+    partial_sum[threadIdx.x] = d_vector[i] + d_vector[i + blockDim.x];
+    __syncthreads();
+
+    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
+        if (threadIdx.x < s) {
+            partial_sum[threadIdx.x] += partial_sum[threadIdx.x + s];
+        }
+        __syncthreads();
+    }
+
+    if (threadIdx.x == 0) {
+        d_output[blockIdx.x] = partial_sum[0];
+    }
+}
+
+__global__ void CUDANet::Kernels::softmax_div(
+    const float* __restrict__ src,
+    float* __restrict__ dst,
+    const float* __restrict__ sum,
+    const unsigned int len
+) {
+    int stride = gridDim.x * blockDim.x;
+    int tid = blockDim.x * blockIdx.x + threadIdx.x;
+
+    for (int i = tid; i < len; i += stride) {
+        dst[i] = src[i] / sum[0];
+    }
+}
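A plain C++ reference implementation is useful when unit-testing this pipeline; the sketch below is hypothetical, not part of the commit, and computes the same three steps on the host. Note it subtracts the maximum before exponentiating, the standard trick for numerical stability, whereas softmax_exp above exponentiates raw inputs, so the two agree only while exp(src[i]) stays within float range.

```cpp
// Host-side reference softmax for tests (hypothetical, not from the commit).
#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> softmaxReference(const std::vector<float>& src) {
    // Shifting by the max leaves the result mathematically unchanged but
    // keeps exp() from overflowing for large logits.
    const float maxVal = *std::max_element(src.begin(), src.end());

    std::vector<float> dst(src.size());
    float sum = 0.0f;
    for (size_t i = 0; i < src.size(); ++i) {
        dst[i] = std::exp(src[i] - maxVal);  // device analogue: softmax_exp
        sum += dst[i];                       // device analogue: softmax_sum
    }
    for (size_t i = 0; i < src.size(); ++i) {
        dst[i] /= sum;                       // device analogue: softmax_div
    }
    return dst;
}
```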
src/kernels/activations.cu (deleted, 29 lines)
@@ -1,29 +0,0 @@
-#include <functional>
-
-#include "activations.cuh"
-
-__global__ void CUDANet::Kernels::sigmoid(
-    const float* __restrict__ src,
-    float* __restrict__ dst,
-    int len
-) {
-    int stride = gridDim.x * blockDim.x;
-    int tid = blockDim.x * blockIdx.x + threadIdx.x;
-
-    for (int i = tid; i < len; i += stride) {
-        dst[i] = 1.0 / (1.0 + exp(-src[i]));
-    }
-}
-
-__global__ void CUDANet::Kernels::relu(
-    const float* __restrict__ src,
-    float* __restrict__ dst,
-    int len
-) {
-    int stride = gridDim.x * blockDim.x;
-    int tid = blockDim.x * blockIdx.x + threadIdx.x;
-
-    for (int i = tid; i < len; i += stride) {
-        dst[i] = src[i] < 0.0 ? 0.0 : src[i];
-    }
-}
src/kernels/matmul.cu
@@ -47,15 +47,3 @@ __global__ void CUDANet::Kernels::vec_vec_add(
     }
     d_output[tid] = d_vector1[tid] + d_vector2[tid];
 }
-
-__global__ void CUDANet::Kernels::reduce_sum(
-    const float* __restrict__ d_vector,
-    float* __restrict__ d_output,
-    const unsigned int w
-) {
-    int tid = blockDim.x * blockIdx.x + threadIdx.x;
-
-    __shared__ float shared[BLOCK_SIZE];
-    shared[threadIdx.x] = d_vector[tid];
-    __syncthreads();
-}
src/layers/activation.cu (new file, 60 lines)
@@ -0,0 +1,60 @@
+#include "activation.cuh"
+
+#include "cuda_helper.cuh"
+#include "activation_functions.cuh"
+
+using namespace CUDANet;
+
+Layers::Activation::Activation(ActivationType activation, const unsigned int length)
+    : activationType(activation), length(length) {
+
+    if (activationType == SOFTMAX) {
+        d_softmax_sum = nullptr;
+        CUDA_CHECK(cudaMalloc((void**)&d_softmax_sum, sizeof(float) * length));
+    }
+
+    gridSize = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
+}
+
+Layers::Activation::~Activation() {
+    if (activationType == SOFTMAX) {
+        cudaFree(d_softmax_sum);
+    }
+}
+
+void Layers::Activation::activate(float* __restrict__ d_input) {
+    switch (activationType) {
+        case SIGMOID:
+            Kernels::sigmoid<<<gridSize, BLOCK_SIZE>>>(
+                d_input, d_input, length
+            );
+            break;
+
+        case RELU:
+            Kernels::relu<<<gridSize, BLOCK_SIZE>>>(
+                d_input, d_input, length
+            );
+            break;
+
+        case SOFTMAX:
+            Kernels::softmax_exp<<<gridSize, BLOCK_SIZE>>>(
+                d_input, d_input, length
+            );
+
+            Kernels::softmax_sum<<<gridSize, BLOCK_SIZE>>>(
+                d_input, d_softmax_sum, length
+            );
+
+            Kernels::softmax_sum<<<1, BLOCK_SIZE>>>(
+                d_softmax_sum, d_softmax_sum, length
+            );
+
+            Kernels::softmax_div<<<gridSize, BLOCK_SIZE>>>(
+                d_input, d_input, d_softmax_sum, length
+            );
+            break;
+
+        default:
+            break;
+    }
+}
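One detail worth noting in this path: softmax_sum never checks its w argument, so each launch reads a fixed 2 * gridDim.x * blockDim.x element window (every thread loads two values before the tree reduction). The code therefore assumes buffers are effectively padded to a multiple of 2 * BLOCK_SIZE, and the second <<<1, BLOCK_SIZE>>> launch re-reads d_softmax_sum beyond the gridDim.x partial sums actually written by the first pass. A bounds-checked variant would look like the sketch below (hypothetical, not the committed kernel; BLOCK_SIZE again stands in for the cuda_helper.cuh definition):

```cpp
// Hypothetical bounds-checked variant of softmax_sum (not the committed code).
__global__ void softmax_sum_guarded(
    const float* __restrict__ d_vector,
    float* __restrict__ d_output,
    const unsigned int w
) {
    __shared__ float partial_sum[BLOCK_SIZE];
    unsigned int i = blockIdx.x * blockDim.x * 2 + threadIdx.x;

    // Treat out-of-range lanes as zero so the sum is exact for any w,
    // not just multiples of 2 * BLOCK_SIZE.
    float a = (i < w) ? d_vector[i] : 0.0f;
    float b = (i + blockDim.x < w) ? d_vector[i + blockDim.x] : 0.0f;
    partial_sum[threadIdx.x] = a + b;
    __syncthreads();

    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) {
            partial_sum[threadIdx.x] += partial_sum[threadIdx.x + s];
        }
        __syncthreads();
    }

    if (threadIdx.x == 0) {
        d_output[blockIdx.x] = partial_sum[0];
    }
}
```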
src/layers/conv2d.cu
@@ -1,7 +1,7 @@
 #include <iostream>
 #include <string>

-#include "activations.cuh"
+#include "activation.cuh"
 #include "conv2d.cuh"
 #include "convolution.cuh"
 #include "cuda_helper.cuh"
@@ -10,20 +10,19 @@
 using namespace CUDANet;

 Layers::Conv2d::Conv2d(
     int inputSize,
     int inputChannels,
     int kernelSize,
     int stride,
-    Layers::Padding padding,
     int numFilters,
+    Layers::Padding padding,
-    Layers::Activation activation
+    Layers::ActivationType activationType
 )
     : inputSize(inputSize),
       inputChannels(inputChannels),
       kernelSize(kernelSize),
       stride(stride),
-      numFilters(numFilters),
-      activation(activation) {
+      numFilters(numFilters) {
     switch (padding) {
         case SAME:
             outputSize = inputSize;
@@ -39,10 +38,13 @@ Layers::Conv2d::Conv2d(
             break;
     }

+    activation = Layers::Activation(
+        activationType, outputSize * outputSize * numFilters
+    );
+
     d_output = nullptr;
     CUDA_CHECK(cudaMalloc(
-        (void**)&d_output,
-        sizeof(float) * outputSize * outputSize * numFilters
+        (void**)&d_output, sizeof(float) * outputSize * outputSize * numFilters
     ));

     weights.resize(kernelSize * kernelSize * inputChannels * numFilters);
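Worked example of that length argument: in the PaddedTest case further down (5x5 input, SAME padding, 2 filters), the Activation instance spans outputSize * outputSize * numFilters = 5 * 5 * 2 = 50 floats, i.e. one activation pass covers the layer's entire output tensor.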
@@ -131,18 +133,8 @@ float* Layers::Conv2d::forward(const float* d_input) {
         d_biases, d_output, d_output, biases.size()
     );

-    switch (activation) {
-        case SIGMOID:
-            Kernels::sigmoid<<<1, outputSize>>>(d_output, d_output, outputSize);
-            break;
-
-        case RELU:
-            Kernels::relu<<<1, outputSize>>>(d_output, d_output, outputSize);
-            break;
-
-        default:
-            break;
-    }
+    // Apply activation
+    activation.activate(d_output);

     CUDA_CHECK(cudaDeviceSynchronize());
src/layers/dense.cu
@@ -5,7 +5,7 @@
 #include <functional>
 #include <iostream>

-#include "activations.cuh"
+#include "activation.cuh"
 #include "cuda_helper.cuh"
 #include "dense.cuh"
 #include "matmul.cuh"
@@ -15,13 +15,15 @@ using namespace CUDANet;
 Layers::Dense::Dense(
     int inputSize,
     int outputSize,
-    Layers::Activation activation
+    Layers::ActivationType activationType
 )
-    : inputSize(inputSize), outputSize(outputSize), activation(activation) {
+    : inputSize(inputSize), outputSize(outputSize) {
     // Allocate memory for weights and biases
     weights.resize(outputSize * inputSize);
     biases.resize(outputSize);

+    activation = Layers::Activation(activationType, outputSize);
+
     initializeWeights();
     initializeBiases();
@@ -69,22 +71,7 @@ float* Layers::Dense::forward(const float* d_input) {
         d_biases, d_output, d_output, outputSize
     );

-    switch (activation) {
-        case SIGMOID:
-            Kernels::sigmoid<<<biasGridSize, BLOCK_SIZE>>>(
-                d_output, d_output, outputSize
-            );
-            break;
-
-        case RELU:
-            Kernels::relu<<<biasGridSize, BLOCK_SIZE>>>(
-                d_output, d_output, outputSize
-            );
-            break;
-
-        default:
-            break;
-    }
+    activation.activate(d_output);

     CUDA_CHECK(cudaDeviceSynchronize());
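With the new signature a classifier head gets softmax applied inside forward(); a hypothetical construction (the 512/10 sizes are invented for illustration):

```cpp
// Hypothetical classifier head using the new ActivationType (sketch).
#include "dense.cuh"

CUDANet::Layers::Dense makeClassifierHead() {
    // 512 input features -> 10 class scores, normalized by softmax
    // during forward().
    return CUDANet::Layers::Dense(
        512, 10, CUDANet::Layers::ActivationType::SOFTMAX
    );
}
```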
test/CMakeLists.txt
@@ -2,10 +2,11 @@ find_package(GTest REQUIRED)
 include_directories(${GTEST_INCLUDE_DIRS})

 add_executable(test_main
+    EXCLUDE_FROM_ALL
     layers/test_dense.cu
     layers/test_conv2d.cu
     layers/test_input.cu
-    kernels/test_activations.cu
+    kernels/test_activation_functions.cu
     kernels/test_padding.cu
     kernels/test_matmul.cu
 )
test/kernels/test_activation_functions.cu (renamed from test/kernels/test_activations.cu)
@@ -3,7 +3,7 @@

 #include <iostream>

-#include "activations.cuh"
+#include "activation_functions.cuh"

 TEST(ActivationsTest, SigmoidSanityCheck) {
test/layers/test_conv2d.cu
@@ -8,21 +8,21 @@
 class Conv2dTest : public ::testing::Test {
   protected:
     CUDANet::Layers::Conv2d commonTestSetup(
         int inputSize,
         int inputChannels,
         int kernelSize,
         int stride,
-        CUDANet::Layers::Padding padding,
         int numFilters,
+        CUDANet::Layers::Padding padding,
-        CUDANet::Layers::Activation activation,
+        CUDANet::Layers::ActivationType activationType,
         std::vector<float>& input,
         float* kernels,
         float*& d_input
     ) {
         // Create Conv2d layer
         CUDANet::Layers::Conv2d conv2d(
-            inputSize, inputChannels, kernelSize, stride, padding, numFilters,
+            inputSize, inputChannels, kernelSize, stride, numFilters, padding,
-            activation
+            activationType
         );

         conv2d.setWeights(kernels);
@@ -53,13 +53,14 @@ class Conv2dTest : public ::testing::Test {
 };

 TEST_F(Conv2dTest, SimpleTest) {
     int inputSize = 4;
     int inputChannels = 1;
     int kernelSize = 2;
     int stride = 1;
-    CUDANet::Layers::Padding padding = CUDANet::Layers::Padding::VALID;
     int numFilters = 1;
+    CUDANet::Layers::Padding padding = CUDANet::Layers::Padding::VALID;
-    CUDANet::Layers::Activation activation = CUDANet::Layers::Activation::NONE;
+    CUDANet::Layers::ActivationType activationType =
+        CUDANet::Layers::ActivationType::NONE;

     std::vector<float> input = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
                                 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f,
@@ -75,8 +76,8 @@ TEST_F(Conv2dTest, SimpleTest) {
     float* d_output;

     CUDANet::Layers::Conv2d conv2d = commonTestSetup(
-        inputSize, inputChannels, kernelSize, stride, padding, numFilters,
-        activation, input, kernels.data(), d_input
+        inputSize, inputChannels, kernelSize, stride, numFilters, padding,
+        activationType, input, kernels.data(), d_input
     );

     int outputSize = (inputSize - kernelSize) / stride + 1;
@@ -102,13 +103,14 @@ TEST_F(Conv2dTest, SimpleTest) {
 }

 TEST_F(Conv2dTest, PaddedTest) {
     int inputSize = 5;
     int inputChannels = 3;
     int kernelSize = 3;
     int stride = 1;
-    CUDANet::Layers::Padding padding = CUDANet::Layers::Padding::SAME;
     int numFilters = 2;
+    CUDANet::Layers::Padding padding = CUDANet::Layers::Padding::SAME;
-    CUDANet::Layers::Activation activation = CUDANet::Layers::Activation::NONE;
+    CUDANet::Layers::ActivationType activationType =
+        CUDANet::Layers::ActivationType::NONE;

     // clang-format off
     std::vector<float> input = {
@@ -164,8 +166,8 @@ TEST_F(Conv2dTest, PaddedTest) {
     float* d_output;

     CUDANet::Layers::Conv2d conv2d = commonTestSetup(
-        inputSize, inputChannels, kernelSize, stride, padding, numFilters,
-        activation, input, kernels.data(), d_input
+        inputSize, inputChannels, kernelSize, stride, numFilters, padding,
+        activationType, input, kernels.data(), d_input
     );

     EXPECT_EQ(inputSize, conv2d.getOutputSize());
@@ -203,13 +205,14 @@ TEST_F(Conv2dTest, PaddedTest) {
 }

 TEST_F(Conv2dTest, StridedPaddedConvolution) {
     int inputSize = 5;
     int inputChannels = 2;
     int kernelSize = 3;
     int stride = 2;
     int numFilters = 2;
     CUDANet::Layers::Padding padding = CUDANet::Layers::Padding::SAME;
-    CUDANet::Layers::Activation activation = CUDANet::Layers::Activation::RELU;
+    CUDANet::Layers::ActivationType activationType =
+        CUDANet::Layers::ActivationType::RELU;

     // clang-format off
     std::vector<float> input = {
@@ -250,8 +253,8 @@ TEST_F(Conv2dTest, StridedPaddedConvolution) {
     float* d_output;

     CUDANet::Layers::Conv2d conv2d = commonTestSetup(
-        inputSize, inputChannels, kernelSize, stride, padding, numFilters,
-        activation, input, kernels.data(), d_input
+        inputSize, inputChannels, kernelSize, stride, numFilters, padding,
+        activationType, input, kernels.data(), d_input
     );

     EXPECT_EQ(inputSize, conv2d.getOutputSize());
test/layers/test_dense.cu
@@ -3,7 +3,7 @@

 #include <iostream>

-#include "activations.cuh"
+#include "activation.cuh"
 #include "dense.cuh"

 class DenseLayerTest : public ::testing::Test {
@@ -15,10 +15,10 @@ class DenseLayerTest : public ::testing::Test {
         float* weights,
         float* biases,
         float*& d_input,
-        CUDANet::Layers::Activation activation
+        CUDANet::Layers::ActivationType activationType
     ) {
         // Create Dense layer
-        CUDANet::Layers::Dense denseLayer(inputSize, outputSize, activation);
+        CUDANet::Layers::Dense denseLayer(inputSize, outputSize, activationType);

         // Set weights and biases
         denseLayer.setWeights(weights);
@@ -53,7 +53,7 @@ TEST_F(DenseLayerTest, Init) {
         int outputSize = j;

         CUDANet::Layers::Dense denseLayer(
-            inputSize, outputSize, CUDANet::Layers::Activation::SIGMOID
+            inputSize, outputSize, CUDANet::Layers::ActivationType::SIGMOID
         );
     }
 }
@@ -74,7 +74,7 @@ TEST_F(DenseLayerTest, setWeights) {
     // clang-format on

     CUDANet::Layers::Dense denseLayer(
-        inputSize, outputSize, CUDANet::Layers::Activation::SIGMOID
+        inputSize, outputSize, CUDANet::Layers::ActivationType::SIGMOID
     );

     denseLayer.setWeights(weights.data());
@@ -101,7 +101,7 @@ TEST_F(DenseLayerTest, ForwardUnitWeightMatrixLinear) {

     CUDANet::Layers::Dense denseLayer = commonTestSetup(
         inputSize, outputSize, input, weights.data(), biases.data(), d_input,
-        CUDANet::Layers::Activation::NONE
+        CUDANet::Layers::ActivationType::NONE
     );
     d_output = denseLayer.forward(d_input);

@@ -142,7 +142,7 @@ TEST_F(DenseLayerTest, ForwardRandomWeightMatrixRelu) {

     CUDANet::Layers::Dense denseLayer = commonTestSetup(
         inputSize, outputSize, input, weights.data(), biases.data(), d_input,
-        CUDANet::Layers::Activation::RELU
+        CUDANet::Layers::ActivationType::RELU
     );

     d_output = denseLayer.forward(d_input);
@@ -187,7 +187,7 @@ TEST_F(DenseLayerTest, ForwardRandomWeightMatrixSigmoid) {

     CUDANet::Layers::Dense denseLayer = commonTestSetup(
         inputSize, outputSize, input, weights.data(), biases.data(), d_input,
-        CUDANet::Layers::Activation::SIGMOID
+        CUDANet::Layers::ActivationType::SIGMOID
     );

     d_output = denseLayer.forward(d_input);