mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-06 17:54:27 +00:00
Refactor layers
This commit is contained in:
@@ -7,22 +7,19 @@ project(CUDANet
|
|||||||
find_package(CUDAToolkit REQUIRED)
|
find_package(CUDAToolkit REQUIRED)
|
||||||
include_directories(${CUDAToolkit_INCLUDE_DIRS})
|
include_directories(${CUDAToolkit_INCLUDE_DIRS})
|
||||||
|
|
||||||
|
file(GLOB_RECURSE LIBRARY_SOURCES
|
||||||
|
src/*.cu
|
||||||
|
src/utils/*.cu
|
||||||
|
src/kernels/*.cu
|
||||||
|
src/layers/*.cu)
|
||||||
|
|
||||||
set(LIBRARY_SOURCES
|
set(LIBRARY_SOURCES
|
||||||
src/utils/cuda_helper.cu
|
${LIBRARY_SOURCES}
|
||||||
src/kernels/activation_functions.cu
|
|
||||||
src/kernels/convolution.cu
|
|
||||||
src/kernels/matmul.cu
|
|
||||||
src/layers/add.cu
|
|
||||||
src/layers/dense.cu
|
|
||||||
src/layers/conv2d.cu
|
|
||||||
src/layers/concat.cu
|
|
||||||
src/layers/input.cu
|
|
||||||
src/layers/activation.cu
|
|
||||||
)
|
)
|
||||||
|
|
||||||
set(CMAKE_CUDA_ARCHITECTURES 75)
|
set(CMAKE_CUDA_ARCHITECTURES 75)
|
||||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||||
set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} -arch=sm_75)
|
# set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} -arch=sm_75)
|
||||||
|
|
||||||
# Build static library
|
# Build static library
|
||||||
add_library(${PROJECT_NAME} STATIC ${LIBRARY_SOURCES})
|
add_library(${PROJECT_NAME} STATIC ${LIBRARY_SOURCES})
|
||||||
|
|||||||
@@ -13,6 +13,10 @@ namespace CUDANet::Layers {
|
|||||||
*/
|
*/
|
||||||
enum ActivationType { SIGMOID, RELU, SOFTMAX, NONE };
|
enum ActivationType { SIGMOID, RELU, SOFTMAX, NONE };
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Utility class that performs activation
|
||||||
|
*
|
||||||
|
*/
|
||||||
class Activation {
|
class Activation {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
|||||||
@@ -19,14 +19,13 @@ class Add {
|
|||||||
~Add();
|
~Add();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Adds the two inputs
|
* @brief Adds first input to second input
|
||||||
*
|
*
|
||||||
* @param d_inputA Device pointer to the first input
|
* @param d_inputA Device pointer to the first input
|
||||||
* @param d_inputB Device pointer to the second input
|
* @param d_inputB Device pointer to the second input
|
||||||
*
|
*
|
||||||
* @return Device pointer to the output
|
|
||||||
*/
|
*/
|
||||||
float* forward(const float* d_inputA, const float* d_inputB);
|
void forward(const float* d_inputA, const float* d_inputB);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int inputSize;
|
int inputSize;
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
|
|
||||||
#include "activation.cuh"
|
#include "activation.cuh"
|
||||||
#include "convolution.cuh"
|
#include "convolution.cuh"
|
||||||
#include "weighted_layer.cuh"
|
#include "layer.cuh"
|
||||||
|
|
||||||
namespace CUDANet::Layers {
|
namespace CUDANet::Layers {
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "weighted_layer.cuh"
|
#include "layer.cuh"
|
||||||
|
|
||||||
namespace CUDANet::Layers {
|
namespace CUDANet::Layers {
|
||||||
|
|
||||||
|
|||||||
@@ -16,15 +16,36 @@ namespace CUDANet::Layers {
|
|||||||
enum Padding { SAME, VALID };
|
enum Padding { SAME, VALID };
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Base class for all layers
|
* @brief Basic Sequential Layer
|
||||||
|
*
|
||||||
*/
|
*/
|
||||||
class WeightedLayer {
|
class SequentialLayer {
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* @brief Destroy the Sequential Layer
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
virtual ~SequentialLayer() {};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Forward propagation virtual function
|
||||||
|
*
|
||||||
|
* @param input Device pointer to the input
|
||||||
|
* @return float* Device pointer to the output
|
||||||
|
*/
|
||||||
|
virtual float* forward(const float* input) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Base class for layers with weights and biases
|
||||||
|
*/
|
||||||
|
class WeightedLayer : public SequentialLayer {
|
||||||
public:
|
public:
|
||||||
/**
|
/**
|
||||||
* @brief Destroy the ILayer object
|
* @brief Destroy the ILayer object
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
virtual ~WeightedLayer() {}
|
virtual ~WeightedLayer() {};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Virtual function for forward pass
|
* @brief Virtual function for forward pass
|
||||||
@@ -32,24 +53,23 @@ class WeightedLayer {
|
|||||||
* @param input (Device) Pointer to the input
|
* @param input (Device) Pointer to the input
|
||||||
* @return float* Device pointer to the output
|
* @return float* Device pointer to the output
|
||||||
*/
|
*/
|
||||||
virtual float* forward(const float* input) = 0;
|
virtual float* forward(const float* input) = 0;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Virtual function for setting weights
|
* @brief Virtual function for setting weights
|
||||||
*
|
*
|
||||||
* @param weights Pointer to the weights
|
* @param weights Pointer to the weights
|
||||||
*/
|
*/
|
||||||
virtual void setWeights(const float* weights) = 0;
|
virtual void setWeights(const float* weights) = 0;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Virtual function for setting biases
|
* @brief Virtual function for setting biases
|
||||||
*
|
*
|
||||||
* @param biases Pointer to the biases
|
* @param biases Pointer to the biases
|
||||||
*/
|
*/
|
||||||
virtual void setBiases(const float* biases) = 0;
|
virtual void setBiases(const float* biases) = 0;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Initialize the weights
|
* @brief Initialize the weights
|
||||||
*/
|
*/
|
||||||
@@ -58,7 +78,7 @@ class WeightedLayer {
|
|||||||
/**
|
/**
|
||||||
* @brief Initialize the biases
|
* @brief Initialize the biases
|
||||||
*/
|
*/
|
||||||
virtual void initializeBiases() = 0;
|
virtual void initializeBiases() = 0;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Copy the weights and biases to the device
|
* @brief Copy the weights and biases to the device
|
||||||
@@ -3,9 +3,9 @@
|
|||||||
#include "cuda_helper.cuh"
|
#include "cuda_helper.cuh"
|
||||||
#include "activation_functions.cuh"
|
#include "activation_functions.cuh"
|
||||||
|
|
||||||
using namespace CUDANet;
|
using namespace CUDANet::Layers;
|
||||||
|
|
||||||
Layers::Activation::Activation(ActivationType activation, const unsigned int length)
|
Activation::Activation(ActivationType activation, const unsigned int length)
|
||||||
: activationType(activation), length(length) {
|
: activationType(activation), length(length) {
|
||||||
|
|
||||||
if (activationType == SOFTMAX) {
|
if (activationType == SOFTMAX) {
|
||||||
@@ -16,13 +16,13 @@ Layers::Activation::Activation(ActivationType activation, const unsigned int len
|
|||||||
gridSize = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
gridSize = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||||
}
|
}
|
||||||
|
|
||||||
Layers::Activation::~Activation() {
|
Activation::~Activation() {
|
||||||
if (activationType == SOFTMAX) {
|
if (activationType == SOFTMAX) {
|
||||||
cudaFree(d_softmax_sum);
|
cudaFree(d_softmax_sum);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Layers::Activation::activate(float* __restrict__ d_input) {
|
void Activation::activate(float* __restrict__ d_input) {
|
||||||
|
|
||||||
switch (activationType) {
|
switch (activationType) {
|
||||||
case SIGMOID:
|
case SIGMOID:
|
||||||
|
|||||||
@@ -2,10 +2,10 @@
|
|||||||
#include "matmul.cuh"
|
#include "matmul.cuh"
|
||||||
#include "cuda_helper.cuh"
|
#include "cuda_helper.cuh"
|
||||||
|
|
||||||
using namespace CUDANet;
|
using namespace CUDANet::Layers;
|
||||||
|
|
||||||
|
|
||||||
Layers::Add::Add(int inputSize)
|
Add::Add(int inputSize)
|
||||||
: inputSize(inputSize) {
|
: inputSize(inputSize) {
|
||||||
|
|
||||||
d_output = nullptr;
|
d_output = nullptr;
|
||||||
@@ -15,12 +15,12 @@ Layers::Add::Add(int inputSize)
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
Layers::Add::~Add() {
|
Add::~Add() {
|
||||||
cudaFree(d_output);
|
cudaFree(d_output);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
float* Layers::Add::forward(const float* d_inputA, const float* d_inputB) {
|
void Add::forward(const float* d_inputA, const float* d_inputB) {
|
||||||
|
|
||||||
Kernels::vec_vec_add<<<gridSize, BLOCK_SIZE>>>(
|
Kernels::vec_vec_add<<<gridSize, BLOCK_SIZE>>>(
|
||||||
d_inputA, d_inputB, d_output, inputSize
|
d_inputA, d_inputB, d_output, inputSize
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
#include "concat.cuh"
|
#include "concat.cuh"
|
||||||
#include "cuda_helper.cuh"
|
#include "cuda_helper.cuh"
|
||||||
|
|
||||||
using namespace CUDANet;
|
using namespace CUDANet::Layers;
|
||||||
|
|
||||||
|
|
||||||
Layers::Concat::Concat(const unsigned int inputASize, const unsigned int inputBSize)
|
Concat::Concat(const unsigned int inputASize, const unsigned int inputBSize)
|
||||||
: inputASize(inputASize), inputBSize(inputBSize) {
|
: inputASize(inputASize), inputBSize(inputBSize) {
|
||||||
|
|
||||||
d_output = nullptr;
|
d_output = nullptr;
|
||||||
@@ -14,12 +14,12 @@ Layers::Concat::Concat(const unsigned int inputASize, const unsigned int inputBS
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Layers::Concat::~Concat() {
|
Concat::~Concat() {
|
||||||
cudaFree(d_output);
|
cudaFree(d_output);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
float* Layers::Concat::forward(const float* d_input_A, const float* d_input_B) {
|
float* Concat::forward(const float* d_input_A, const float* d_input_B) {
|
||||||
CUDA_CHECK(cudaMemcpy(
|
CUDA_CHECK(cudaMemcpy(
|
||||||
d_output, d_input_A, sizeof(float) * inputASize, cudaMemcpyDeviceToDevice
|
d_output, d_input_A, sizeof(float) * inputASize, cudaMemcpyDeviceToDevice
|
||||||
));
|
));
|
||||||
|
|||||||
@@ -7,16 +7,16 @@
|
|||||||
#include "cuda_helper.cuh"
|
#include "cuda_helper.cuh"
|
||||||
#include "matmul.cuh"
|
#include "matmul.cuh"
|
||||||
|
|
||||||
using namespace CUDANet;
|
using namespace CUDANet::Layers;
|
||||||
|
|
||||||
Layers::Conv2d::Conv2d(
|
Conv2d::Conv2d(
|
||||||
int inputSize,
|
int inputSize,
|
||||||
int inputChannels,
|
int inputChannels,
|
||||||
int kernelSize,
|
int kernelSize,
|
||||||
int stride,
|
int stride,
|
||||||
int numFilters,
|
int numFilters,
|
||||||
Layers::Padding padding,
|
Padding padding,
|
||||||
Layers::ActivationType activationType
|
ActivationType activationType
|
||||||
)
|
)
|
||||||
: inputSize(inputSize),
|
: inputSize(inputSize),
|
||||||
inputChannels(inputChannels),
|
inputChannels(inputChannels),
|
||||||
@@ -68,31 +68,31 @@ Layers::Conv2d::Conv2d(
|
|||||||
toCuda();
|
toCuda();
|
||||||
}
|
}
|
||||||
|
|
||||||
Layers::Conv2d::~Conv2d() {
|
Conv2d::~Conv2d() {
|
||||||
cudaFree(d_output);
|
cudaFree(d_output);
|
||||||
cudaFree(d_weights);
|
cudaFree(d_weights);
|
||||||
cudaFree(d_biases);
|
cudaFree(d_biases);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Layers::Conv2d::initializeWeights() {
|
void Conv2d::initializeWeights() {
|
||||||
std::fill(weights.begin(), weights.end(), 0.0f);
|
std::fill(weights.begin(), weights.end(), 0.0f);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Layers::Conv2d::initializeBiases() {
|
void Conv2d::initializeBiases() {
|
||||||
std::fill(biases.begin(), biases.end(), 0.0f);
|
std::fill(biases.begin(), biases.end(), 0.0f);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Layers::Conv2d::setWeights(const float* weights_input) {
|
void Conv2d::setWeights(const float* weights_input) {
|
||||||
std::copy(weights_input, weights_input + weights.size(), weights.begin());
|
std::copy(weights_input, weights_input + weights.size(), weights.begin());
|
||||||
toCuda();
|
toCuda();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Layers::Conv2d::setBiases(const float* biases_input) {
|
void Conv2d::setBiases(const float* biases_input) {
|
||||||
std::copy(biases_input, biases_input + biases.size(), biases.begin());
|
std::copy(biases_input, biases_input + biases.size(), biases.begin());
|
||||||
toCuda();
|
toCuda();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Layers::Conv2d::toCuda() {
|
void Conv2d::toCuda() {
|
||||||
CUDA_CHECK(cudaMemcpy(
|
CUDA_CHECK(cudaMemcpy(
|
||||||
d_weights, weights.data(),
|
d_weights, weights.data(),
|
||||||
sizeof(float) * kernelSize * kernelSize * inputChannels * numFilters,
|
sizeof(float) * kernelSize * kernelSize * inputChannels * numFilters,
|
||||||
@@ -106,7 +106,7 @@ void Layers::Conv2d::toCuda() {
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
float* Layers::Conv2d::forward(const float* d_input) {
|
float* Conv2d::forward(const float* d_input) {
|
||||||
// Convolve
|
// Convolve
|
||||||
int THREADS_PER_BLOCK = outputSize * outputSize * numFilters;
|
int THREADS_PER_BLOCK = outputSize * outputSize * numFilters;
|
||||||
Kernels::convolution<<<1, THREADS_PER_BLOCK>>>(
|
Kernels::convolution<<<1, THREADS_PER_BLOCK>>>(
|
||||||
|
|||||||
@@ -10,19 +10,19 @@
|
|||||||
#include "dense.cuh"
|
#include "dense.cuh"
|
||||||
#include "matmul.cuh"
|
#include "matmul.cuh"
|
||||||
|
|
||||||
using namespace CUDANet;
|
using namespace CUDANet::Layers;
|
||||||
|
|
||||||
Layers::Dense::Dense(
|
Dense::Dense(
|
||||||
int inputSize,
|
int inputSize,
|
||||||
int outputSize,
|
int outputSize,
|
||||||
Layers::ActivationType activationType
|
ActivationType activationType
|
||||||
)
|
)
|
||||||
: inputSize(inputSize), outputSize(outputSize) {
|
: inputSize(inputSize), outputSize(outputSize) {
|
||||||
// Allocate memory for weights and biases
|
// Allocate memory for weights and biases
|
||||||
weights.resize(outputSize * inputSize);
|
weights.resize(outputSize * inputSize);
|
||||||
biases.resize(outputSize);
|
biases.resize(outputSize);
|
||||||
|
|
||||||
activation = Layers::Activation(activationType, outputSize);
|
activation = Activation(activationType, outputSize);
|
||||||
|
|
||||||
initializeWeights();
|
initializeWeights();
|
||||||
initializeBiases();
|
initializeBiases();
|
||||||
@@ -47,22 +47,22 @@ Layers::Dense::Dense(
|
|||||||
biasGridSize = (outputSize + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
biasGridSize = (outputSize + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||||
}
|
}
|
||||||
|
|
||||||
Layers::Dense::~Dense() {
|
Dense::~Dense() {
|
||||||
// Free GPU memory
|
// Free GPU memory
|
||||||
cudaFree(d_output);
|
cudaFree(d_output);
|
||||||
cudaFree(d_weights);
|
cudaFree(d_weights);
|
||||||
cudaFree(d_biases);
|
cudaFree(d_biases);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Layers::Dense::initializeWeights() {
|
void Dense::initializeWeights() {
|
||||||
std::fill(weights.begin(), weights.end(), 0.0f);
|
std::fill(weights.begin(), weights.end(), 0.0f);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Layers::Dense::initializeBiases() {
|
void Dense::initializeBiases() {
|
||||||
std::fill(biases.begin(), biases.end(), 0.0f);
|
std::fill(biases.begin(), biases.end(), 0.0f);
|
||||||
}
|
}
|
||||||
|
|
||||||
float* Layers::Dense::forward(const float* d_input) {
|
float* Dense::forward(const float* d_input) {
|
||||||
Kernels::mat_vec_mul<<<forwardGridSize, BLOCK_SIZE>>>(
|
Kernels::mat_vec_mul<<<forwardGridSize, BLOCK_SIZE>>>(
|
||||||
d_weights, d_input, d_output, inputSize, outputSize
|
d_weights, d_input, d_output, inputSize, outputSize
|
||||||
);
|
);
|
||||||
@@ -78,7 +78,7 @@ float* Layers::Dense::forward(const float* d_input) {
|
|||||||
return d_output;
|
return d_output;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Layers::Dense::toCuda() {
|
void Dense::toCuda() {
|
||||||
CUDA_CHECK(cudaMemcpy(
|
CUDA_CHECK(cudaMemcpy(
|
||||||
d_weights, weights.data(), sizeof(float) * inputSize * outputSize,
|
d_weights, weights.data(), sizeof(float) * inputSize * outputSize,
|
||||||
cudaMemcpyHostToDevice
|
cudaMemcpyHostToDevice
|
||||||
@@ -89,12 +89,12 @@ void Layers::Dense::toCuda() {
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
void Layers::Dense::setWeights(const float* weights_input) {
|
void Dense::setWeights(const float* weights_input) {
|
||||||
std::copy(weights_input, weights_input + weights.size(), weights.begin());
|
std::copy(weights_input, weights_input + weights.size(), weights.begin());
|
||||||
toCuda();
|
toCuda();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Layers::Dense::setBiases(const float* biases_input) {
|
void Dense::setBiases(const float* biases_input) {
|
||||||
std::copy(biases_input, biases_input + biases.size(), biases.begin());
|
std::copy(biases_input, biases_input + biases.size(), biases.begin());
|
||||||
toCuda();
|
toCuda();
|
||||||
}
|
}
|
||||||
@@ -1,14 +1,14 @@
|
|||||||
#include "cuda_helper.cuh"
|
#include "cuda_helper.cuh"
|
||||||
#include "input.cuh"
|
#include "input.cuh"
|
||||||
|
|
||||||
using namespace CUDANet;
|
using namespace CUDANet::Layers;
|
||||||
|
|
||||||
Layers::Input::Input(int inputSize) : inputSize(inputSize) {
|
Input::Input(int inputSize) : inputSize(inputSize) {
|
||||||
d_output = nullptr;
|
d_output = nullptr;
|
||||||
CUDA_CHECK(cudaMalloc((void**)&d_output, sizeof(float) * inputSize));
|
CUDA_CHECK(cudaMalloc((void**)&d_output, sizeof(float) * inputSize));
|
||||||
}
|
}
|
||||||
|
|
||||||
Layers::Input::~Input() {
|
Input::~Input() {
|
||||||
cudaFree(d_output);
|
cudaFree(d_output);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -19,7 +19,7 @@ Args
|
|||||||
const float* input Host pointer to input data
|
const float* input Host pointer to input data
|
||||||
float* d_output Device pointer to input data copied to device
|
float* d_output Device pointer to input data copied to device
|
||||||
*/
|
*/
|
||||||
float* Layers::Input::forward(const float* input) {
|
float* Input::forward(const float* input) {
|
||||||
CUDA_CHECK(cudaMemcpy(
|
CUDA_CHECK(cudaMemcpy(
|
||||||
d_output, input, sizeof(float) * inputSize, cudaMemcpyHostToDevice
|
d_output, input, sizeof(float) * inputSize, cudaMemcpyHostToDevice
|
||||||
));
|
));
|
||||||
|
|||||||
Reference in New Issue
Block a user