Mirror of https://github.com/lordmathis/CUDANet.git (synced 2025-11-07 02:04:26 +00:00)
Move cuda source to separate directory
src/cuda/kernels/activation_functions.cu (new file, 30 lines)
@@ -0,0 +1,30 @@
#include "activation_functions.cuh"
#include "cuda_helper.cuh"

using namespace CUDANet;

__global__ void Kernels::sigmoid(
    const float* __restrict__ src,
    float* __restrict__ dst,
    const unsigned int len
) {
    int stride = gridDim.x * blockDim.x;
    int tid    = blockDim.x * blockIdx.x + threadIdx.x;

    // Grid-stride loop: each thread handles every stride-th element
    for (int i = tid; i < len; i += stride) {
        dst[i] = 1.0f / (1.0f + expf(-src[i]));
    }
}

__global__ void Kernels::relu(
    const float* __restrict__ src,
    float* __restrict__ dst,
    const unsigned int len
) {
    int stride = gridDim.x * blockDim.x;
    int tid    = blockDim.x * blockIdx.x + threadIdx.x;

    for (int i = tid; i < len; i += stride) {
        dst[i] = src[i] < 0.0f ? 0.0f : src[i];
    }
}
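A minimal host-side sketch (illustrative, not part of this commit) of how these grid-stride kernels can be launched; it assumes BLOCK_SIZE comes from cuda_helper.cuh, as it does elsewhere in this commit, and the buffer and count names are hypothetical:

#include <vector>
#include "activation_functions.cuh"
#include "cuda_helper.cuh"

void relu_example() {
    const unsigned int n = 1024;
    std::vector<float> h(n, -1.0f);

    float* d = nullptr;
    CUDA_CHECK(cudaMalloc((void**)&d, sizeof(float) * n));
    CUDA_CHECK(cudaMemcpy(d, h.data(), sizeof(float) * n, cudaMemcpyHostToDevice));

    // One thread per element; any grid size works because the kernel strides over the array.
    int gridSize = (n + BLOCK_SIZE - 1) / BLOCK_SIZE;
    CUDANet::Kernels::relu<<<gridSize, BLOCK_SIZE>>>(d, d, n);
    CUDA_CHECK(cudaGetLastError());

    CUDA_CHECK(cudaMemcpy(h.data(), d, sizeof(float) * n, cudaMemcpyDeviceToHost));
    CUDA_CHECK(cudaFree(d));
}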
src/cuda/kernels/convolution.cu (new file, 60 lines)
@@ -0,0 +1,60 @@
#include <iostream>

#include "convolution.cuh"

using namespace CUDANet;

__global__ void Kernels::convolution(
    const float* __restrict__ d_input,
    const float* __restrict__ d_kernel,
    const float* __restrict__ d_bias,
    float* __restrict__ d_output,
    const shape2d inputSize,
    const int nChannels,
    const shape2d paddingSize,
    const shape2d kernelSize,
    const shape2d stride,
    const int nFilters,
    const shape2d outputSize
) {
    int j = blockDim.x * blockIdx.x + threadIdx.x;
    int i = blockDim.y * blockIdx.y + threadIdx.y;
    int f = blockDim.z * blockIdx.z + threadIdx.z;

    if (i >= outputSize.first || j >= outputSize.second || f >= nFilters) {
        return;
    }

    float sum = 0.0f;

    // Iterate over kernel and input matrix
    for (int c = 0; c < nChannels; c++) {
        for (int k = 0; k < kernelSize.first; k++) {
            for (int l = 0; l < kernelSize.second; l++) {
                // if i, j is in the padding region
                if (i * stride.first + k < paddingSize.first ||
                    i * stride.first + k >= (inputSize.first + paddingSize.first) ||
                    j * stride.second + l < paddingSize.second ||
                    j * stride.second + l >= (inputSize.second + paddingSize.second)) {
                    continue;
                }

                int kernelIndex =
                    f * kernelSize.first * kernelSize.second * nChannels +
                    c * kernelSize.first * kernelSize.second +
                    k * kernelSize.second + l;
                int inputIndex = c * inputSize.first * inputSize.second +
                                 (i * stride.first + k - paddingSize.first) * inputSize.second +
                                 (j * stride.second + l - paddingSize.second);

                sum += d_kernel[kernelIndex] * d_input[inputIndex];
            }
        }
    }

    d_output[f * outputSize.first * outputSize.second + i * outputSize.second + j] =
        sum + d_bias[f];
}
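The outputSize this kernel indexes against follows the usual convolution arithmetic, out = (in + 2 * padding - kernel) / stride + 1 (the same formula conv2d.cu uses below). For example, a 32x32 input with a 3x3 kernel, stride 1 and padding 1 gives (32 + 2 - 3) / 1 + 1 = 32, i.e. a same-size output; with stride 2 it gives (32 + 2 - 3) / 2 + 1 = 16. The output is laid out channel-major: element (f, i, j) lives at f * outputSize.first * outputSize.second + i * outputSize.second + j.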
src/cuda/kernels/matmul.cu (new file, 211 lines)
@@ -0,0 +1,211 @@
#include "cuda_helper.cuh"
#include "matmul.cuh"

using namespace CUDANet;

__global__ void Kernels::mat_vec_mul(
    const float* __restrict__ d_matrix,
    const float* __restrict__ d_vector,
    float* __restrict__ d_output,
    const unsigned int w,
    const unsigned int h
) {
    int tid = blockDim.x * blockIdx.x + threadIdx.x;

    if (tid < h) {
        float temp = 0.0f;

        for (unsigned int j = 0; j < w; j++) {
            temp += d_matrix[tid * w + j] * d_vector[j];
        }

        d_output[tid] = temp;
    }
}

__global__ void Kernels::vec_vec_add(
    const float* __restrict__ d_vector1,
    const float* __restrict__ d_vector2,
    float* __restrict__ d_output,
    const unsigned int w
) {
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    if (tid >= w) {
        return;
    }
    d_output[tid] = d_vector1[tid] + d_vector2[tid];
}

__global__ void Kernels::vec_vec_sub(
    const float* __restrict__ d_vector1,
    const float* __restrict__ d_vector2,
    float* __restrict__ d_output,
    const unsigned int w
) {
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    if (tid >= w) {
        return;
    }
    d_output[tid] = d_vector1[tid] - d_vector2[tid];
}

__global__ void Kernels::vec_vec_mul(
    const float* __restrict__ d_vector1,
    const float* __restrict__ d_vector2,
    float* __restrict__ d_output,
    const unsigned int w
) {
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    if (tid >= w) {
        return;
    }
    d_output[tid] = d_vector1[tid] * d_vector2[tid];
}

__global__ void Kernels::vec_scalar_sub(
    const float* __restrict__ d_src,
    float* __restrict__ d_out,
    const float* __restrict__ d_scalar,
    const unsigned int len
) {
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    if (tid >= len) {
        return;
    }
    d_out[tid] = d_src[tid] - *d_scalar;
}

__global__ void Kernels::vec_scalar_add(
    const float* __restrict__ d_src,
    float* __restrict__ d_out,
    const float* __restrict__ d_scalar,
    const unsigned int len
) {
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    if (tid >= len) {
        return;
    }
    d_out[tid] = d_src[tid] + *d_scalar;
}

__global__ void Kernels::vec_scalar_div(
    const float* __restrict__ d_src,
    float* __restrict__ d_out,
    const float* __restrict__ d_scalar,
    const unsigned int len
) {
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    if (tid >= len) {
        return;
    }
    d_out[tid] = d_src[tid] / *d_scalar;
}

__global__ void Kernels::vec_scalar_mul(
    const float* __restrict__ d_src,
    float* __restrict__ d_out,
    const float* __restrict__ d_scalar,
    const unsigned int len
) {
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    if (tid >= len) {
        return;
    }
    d_out[tid] = d_src[tid] * *d_scalar;
}

__global__ void Kernels::vec_exp(
    const float* __restrict__ src,
    float* __restrict__ dst,
    const unsigned int len
) {
    int stride = gridDim.x * blockDim.x;
    int tid    = blockDim.x * blockIdx.x + threadIdx.x;

    for (int i = tid; i < len; i += stride) {
        dst[i] = expf(src[i]);
    }
}

__global__ void Kernels::vec_sqrt(
    const float* __restrict__ src,
    float* __restrict__ dst,
    const unsigned int len
) {
    int stride = gridDim.x * blockDim.x;
    int tid    = blockDim.x * blockIdx.x + threadIdx.x;

    for (int i = tid; i < len; i += stride) {
        dst[i] = sqrtf(src[i]);
    }
}

__global__ void Kernels::vec_scale(
    const float* __restrict__ src,
    float* __restrict__ dst,
    const float* __restrict__ scale,
    const float* epsilon,
    const unsigned int len
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < len) {
        float inv_std = rsqrtf(*scale + *epsilon);
        dst[idx] = src[idx] * inv_std;
    }
}

__global__ void Kernels::max_reduce(
    const float* __restrict__ d_vector,
    float* __restrict__ d_output,
    const unsigned int len
) {
    __shared__ float shared_max[BLOCK_SIZE];
    int i = blockIdx.x * blockDim.x + threadIdx.x;

    if (i < len) {
        shared_max[threadIdx.x] = d_vector[i];
    } else {
        shared_max[threadIdx.x] = -INFINITY;
    }

    __syncthreads();

    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) {
            shared_max[threadIdx.x] = fmaxf(shared_max[threadIdx.x], shared_max[threadIdx.x + s]);
        }
        __syncthreads();
    }

    if (threadIdx.x == 0) {
        d_output[blockIdx.x] = shared_max[0];
    }
}

__global__ void Kernels::sum_reduce(
    const float* __restrict__ d_vector,
    float* __restrict__ d_output,
    const unsigned int len
) {
    __shared__ float partial_sum[BLOCK_SIZE];
    int i = blockIdx.x * blockDim.x + threadIdx.x;

    if (i < len) {
        partial_sum[threadIdx.x] = d_vector[i];
    } else {
        partial_sum[threadIdx.x] = 0.0f;
    }

    __syncthreads();

    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) {
            partial_sum[threadIdx.x] += partial_sum[threadIdx.x + s];
        }
        __syncthreads();
    }

    if (threadIdx.x == 0) {
        d_output[blockIdx.x] = partial_sum[0];
    }
}
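Note that max_reduce and sum_reduce are single-pass block reductions: each block writes one partial result to d_output[blockIdx.x], so reducing a long vector to a single value takes repeated launches over the shrinking array of partials. The driver loops that do this are Utils::max and Utils::sum in src/cuda/utils/vector.cu later in this commit. The launch block size must not exceed BLOCK_SIZE, since that constant sizes the shared-memory scratch arrays.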
src/cuda/kernels/pooling.cu (new file, 85 lines)
@@ -0,0 +1,85 @@
#include "cuda_helper.cuh"
#include "layer.cuh"
#include "pooling.cuh"

using namespace CUDANet;

__global__ void Kernels::max_pooling(
    const float* __restrict__ d_input,
    float* __restrict__ d_output,
    const shape2d inputSize,
    const shape2d outputSize,
    const int nChannels,
    const shape2d poolingSize,
    const shape2d stride,
    const shape2d padding
) {
    int j = blockDim.x * blockIdx.x + threadIdx.x;
    int i = blockDim.y * blockIdx.y + threadIdx.y;
    int c = blockDim.z * blockIdx.z + threadIdx.z;

    if (i >= outputSize.first || j >= outputSize.second || c >= nChannels) {
        return;
    }

    // Start below any representable value so negative activations are handled correctly
    float max = -INFINITY;

    for (int k = 0; k < poolingSize.first; k++) {
        for (int l = 0; l < poolingSize.second; l++) {
            int inputRow = i * stride.first + k - padding.first;
            int inputCol = j * stride.second + l - padding.second;

            if (inputRow >= 0 && inputRow < inputSize.first && inputCol >= 0 &&
                inputCol < inputSize.second) {
                int inputIndex = c * inputSize.first * inputSize.second +
                                 inputRow * inputSize.second + inputCol;
                if (d_input[inputIndex] > max) {
                    max = d_input[inputIndex];
                }
            }
        }
    }

    d_output[c * outputSize.first * outputSize.second + i * outputSize.second + j] =
        max;
}

__global__ void Kernels::avg_pooling(
    const float* __restrict__ d_input,
    float* __restrict__ d_output,
    const shape2d inputSize,
    const shape2d outputSize,
    const int nChannels,
    const shape2d poolingSize,
    const shape2d stride,
    const shape2d padding
) {
    int j = blockDim.x * blockIdx.x + threadIdx.x;
    int i = blockDim.y * blockIdx.y + threadIdx.y;
    int c = blockDim.z * blockIdx.z + threadIdx.z;

    if (i >= outputSize.first || j >= outputSize.second || c >= nChannels) {
        return;
    }

    float sum = 0.0f;

    for (int k = 0; k < poolingSize.first; k++) {
        for (int l = 0; l < poolingSize.second; l++) {
            int inputRow = i * stride.first + k - padding.first;
            int inputCol = j * stride.second + l - padding.second;

            if (inputRow >= 0 && inputRow < inputSize.first && inputCol >= 0 &&
                inputCol < inputSize.second) {
                int inputIndex = c * inputSize.first * inputSize.second +
                                 inputRow * inputSize.second + inputCol;
                sum += d_input[inputIndex];
            }
        }
    }

    d_output[c * outputSize.first * outputSize.second + i * outputSize.second + j] =
        sum / (poolingSize.first * poolingSize.second);
}
src/cuda/layers/activation.cu (new file, 79 lines)
@@ -0,0 +1,79 @@
#include <iostream>
#include <vector>

#include "activation.cuh"
#include "activation_functions.cuh"
#include "cuda_helper.cuh"
#include "matmul.cuh"
#include "vector.cuh"

using namespace CUDANet::Layers;

Activation::Activation(ActivationType activation, const int length)
    : activationType(activation), length(length) {
    if (activationType == SOFTMAX) {
        d_softmax_sum = nullptr;
        CUDA_CHECK(cudaMalloc((void**)&d_softmax_sum, sizeof(float) * length));

        d_max = nullptr;
        CUDA_CHECK(cudaMalloc((void**)&d_max, sizeof(float) * length));
    }

    gridSize = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
}

Activation::~Activation() {
    if (activationType == SOFTMAX) {
        CUDA_CHECK(cudaFree(d_softmax_sum));
        CUDA_CHECK(cudaFree(d_max));
    }
}

void Activation::activate(float* d_input) {
    switch (activationType) {
        case SIGMOID:
            Kernels::sigmoid<<<gridSize, BLOCK_SIZE>>>(
                d_input, d_input, length
            );
            CUDA_CHECK(cudaGetLastError());
            break;

        case RELU:
            Kernels::relu<<<gridSize, BLOCK_SIZE>>>(d_input, d_input, length);
            CUDA_CHECK(cudaGetLastError());
            break;

        case SOFTMAX:
            // Find max value
            Utils::max(d_input, d_max, length);

            // Subtract max value to improve numerical stability
            Kernels::vec_scalar_sub<<<gridSize, BLOCK_SIZE>>>(
                d_input, d_input, &d_max[0], length
            );
            CUDA_CHECK(cudaGetLastError());

            // Compute exponentials
            Kernels::vec_exp<<<gridSize, BLOCK_SIZE>>>(
                d_input, d_input, length
            );
            CUDA_CHECK(cudaGetLastError());

            // Find sum
            Utils::sum(d_input, d_softmax_sum, length);

            Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
                d_input, d_input, &d_softmax_sum[0], length
            );
            CUDA_CHECK(cudaGetLastError());
            break;

        default:
            break;
    }

    CUDA_CHECK(cudaDeviceSynchronize());
}
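The SOFTMAX branch above implements the numerically stable form softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x)). Subtracting the maximum first keeps every exponential in (0, 1] and avoids overflow, and it does not change the result because the common factor exp(-max(x)) cancels between numerator and denominator.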
src/cuda/layers/add.cu (new file, 31 lines)
@@ -0,0 +1,31 @@
#include "add.cuh"
#include "matmul.cuh"
#include "cuda_helper.cuh"

using namespace CUDANet::Layers;

Add::Add(int inputSize)
    : inputSize(inputSize) {
    d_output = nullptr;
    CUDA_CHECK(cudaMalloc((void**)&d_output, sizeof(float) * inputSize));

    gridSize = (inputSize + BLOCK_SIZE - 1) / BLOCK_SIZE;
}

Add::~Add() {
    cudaFree(d_output);
}

void Add::forward(const float* d_inputA, const float* d_inputB) {
    Kernels::vec_vec_add<<<gridSize, BLOCK_SIZE>>>(
        d_inputA, d_inputB, d_output, inputSize
    );
    CUDA_CHECK(cudaGetLastError());
    CUDA_CHECK(cudaDeviceSynchronize());
}
src/cuda/layers/avg_pooling.cu (new file, 91 lines)
@@ -0,0 +1,91 @@
#include "avg_pooling.cuh"
#include "cuda_helper.cuh"
#include "pooling.cuh"

using namespace CUDANet::Layers;

AvgPooling2d::AvgPooling2d(
    shape2d inputSize,
    int nChannels,
    shape2d poolingSize,
    shape2d stride,
    shape2d padding,
    ActivationType activationType
)
    : inputSize(inputSize),
      nChannels(nChannels),
      poolingSize(poolingSize),
      stride(stride),
      padding(padding) {
    outputSize = {
        (inputSize.first + 2 * padding.first - poolingSize.first) / stride.first + 1,
        (inputSize.second + 2 * padding.second - poolingSize.second) / stride.second + 1
    };

    activation = new Activation(
        activationType, outputSize.first * outputSize.second * nChannels
    );

    d_output = nullptr;
    CUDA_CHECK(cudaMalloc(
        (void**)&d_output,
        sizeof(float) * outputSize.first * outputSize.second * nChannels
    ));
}

AvgPooling2d::~AvgPooling2d() {
    cudaFree(d_output);
    delete activation;
}

float* AvgPooling2d::forward(const float* d_input) {
    dim3 block(8, 8, 8);
    dim3 grid(
        (outputSize.first + block.x - 1) / block.x,
        (outputSize.second + block.y - 1) / block.y,
        (nChannels + block.z - 1) / block.z
    );

    Kernels::avg_pooling<<<grid, block>>>(
        d_input, d_output, inputSize, outputSize, nChannels, poolingSize,
        stride, padding
    );
    CUDA_CHECK(cudaGetLastError());

    activation->activate(d_output);
    CUDA_CHECK(cudaDeviceSynchronize());

    return d_output;
}

int AvgPooling2d::getOutputSize() {
    return outputSize.first * outputSize.second * nChannels;
}

int AvgPooling2d::getInputSize() {
    return inputSize.first * inputSize.second * nChannels;
}

shape2d AvgPooling2d::getOutputDims() {
    return outputSize;
}

AdaptiveAvgPooling2d::AdaptiveAvgPooling2d(shape2d inputShape, int nChannels, shape2d outputShape, ActivationType activationType)
    : AvgPooling2d(inputShape, nChannels, {1, 1}, {1, 1}, {0, 0}, activationType) {
    stride = {inputShape.first / outputShape.first, inputShape.second / outputShape.second};
    poolingSize = {
        inputShape.first - (outputShape.first - 1) * stride.first,
        inputShape.second - (outputShape.second - 1) * stride.second
    };
    padding = {
        (poolingSize.first - 1) / 2,
        (poolingSize.second - 1) / 2
    };
    outputSize = outputShape;

    // Replace the activation and output buffer allocated by the base constructor
    delete activation;
    activation = new Activation(activationType, outputSize.first * outputSize.second * nChannels);

    CUDA_CHECK(cudaFree(d_output));
    CUDA_CHECK(cudaMalloc((void**)&d_output, sizeof(float) * outputSize.first * outputSize.second * nChannels));
}
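A worked example of the adaptive parameters derived above (illustrative numbers, not from the commit): mapping a 7x7 input to a 1x1 output gives stride = 7 / 1 = 7 and poolingSize = 7 - (1 - 1) * 7 = 7, one window covering the whole plane; mapping a 6x6 input to a 2x2 output gives stride = 3 and poolingSize = 6 - 1 * 3 = 3, two non-overlapping 3x3 windows per axis.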
src/cuda/layers/batch_norm.cu (new file, 212 lines)
@@ -0,0 +1,212 @@
#include <vector>

#include "activation.cuh"
#include "batch_norm.cuh"
#include "cuda_helper.cuh"
#include "layer.cuh"
#include "matmul.cuh"
#include "vector.cuh"

using namespace CUDANet::Layers;

BatchNorm2d::BatchNorm2d(
    shape2d inputSize,
    int inputChannels,
    float epsilon,
    ActivationType activationType
)
    : inputSize(inputSize), inputChannels(inputChannels) {
    activation = new Activation(
        activationType, inputSize.first * inputSize.second * inputChannels
    );

    d_output = nullptr;
    CUDA_CHECK(cudaMalloc(
        (void **)&d_output,
        sizeof(float) * inputSize.first * inputSize.second * inputChannels
    ));

    d_running_mean = nullptr;
    CUDA_CHECK(cudaMalloc(
        (void **)&d_running_mean, sizeof(float) * inputChannels
    ));

    d_running_var = nullptr;
    CUDA_CHECK(cudaMalloc(
        (void **)&d_running_var, sizeof(float) * inputChannels
    ));

    d_weights = nullptr;
    CUDA_CHECK(cudaMalloc((void **)&d_weights, sizeof(float) * inputChannels));

    d_biases = nullptr;
    CUDA_CHECK(cudaMalloc((void **)&d_biases, sizeof(float) * inputChannels));

    d_length = nullptr;
    float length = (float)inputSize.first * inputSize.second;
    CUDA_CHECK(cudaMalloc((void **)&d_length, sizeof(float)));
    CUDA_CHECK(
        cudaMemcpy(d_length, &length, sizeof(float), cudaMemcpyHostToDevice)
    );

    d_epsilon = nullptr;
    CUDA_CHECK(cudaMalloc((void **)&d_epsilon, sizeof(float)));
    CUDA_CHECK(
        cudaMemcpy(d_epsilon, &epsilon, sizeof(float), cudaMemcpyHostToDevice)
    );

    weights.resize(inputChannels);
    biases.resize(inputChannels);

    running_mean.resize(inputChannels);
    running_var.resize(inputChannels);

    initializeWeights();
    initializeBiases();
    initializeRunningMean();
    initializeRunningVar();

    toCuda();

    gridSize =
        (inputSize.first * inputSize.second + BLOCK_SIZE - 1) / BLOCK_SIZE;
}

BatchNorm2d::~BatchNorm2d() {
    cudaFree(d_output);
    cudaFree(d_running_mean);
    cudaFree(d_running_var);
    cudaFree(d_weights);
    cudaFree(d_biases);
    cudaFree(d_length);
    cudaFree(d_epsilon);
    // Free the activation allocated in the constructor
    delete activation;
}

void BatchNorm2d::initializeWeights() {
    std::fill(weights.begin(), weights.end(), 1.0f);
}

void BatchNorm2d::initializeBiases() {
    std::fill(biases.begin(), biases.end(), 0.0f);
}

void BatchNorm2d::initializeRunningMean() {
    std::fill(running_mean.begin(), running_mean.end(), 0.0f);
}

void BatchNorm2d::initializeRunningVar() {
    std::fill(running_var.begin(), running_var.end(), 1.0f);
}

void BatchNorm2d::setWeights(const float *weights_input) {
    std::copy(weights_input, weights_input + weights.size(), weights.begin());
    toCuda();
}

std::vector<float> BatchNorm2d::getWeights() {
    return weights;
}

void BatchNorm2d::setBiases(const float *biases_input) {
    std::copy(biases_input, biases_input + biases.size(), biases.begin());
    toCuda();
}

std::vector<float> BatchNorm2d::getBiases() {
    return biases;
}

void BatchNorm2d::setRunningMean(const float* running_mean_input) {
    std::copy(running_mean_input, running_mean_input + inputChannels, running_mean.begin());
    toCuda();
}

std::vector<float> BatchNorm2d::getRunningMean() {
    return running_mean;
}

void BatchNorm2d::setRunningVar(const float* running_var_input) {
    std::copy(running_var_input, running_var_input + inputChannels, running_var.begin());
    toCuda();
}

std::vector<float> BatchNorm2d::getRunningVar() {
    return running_var;
}

void BatchNorm2d::toCuda() {
    CUDA_CHECK(cudaMemcpy(
        d_weights, weights.data(), sizeof(float) * inputChannels,
        cudaMemcpyHostToDevice
    ));
    CUDA_CHECK(cudaMemcpy(
        d_biases, biases.data(), sizeof(float) * inputChannels,
        cudaMemcpyHostToDevice
    ));
    CUDA_CHECK(cudaMemcpy(
        d_running_mean, running_mean.data(), sizeof(float) * inputChannels,
        cudaMemcpyHostToDevice
    ));
    CUDA_CHECK(cudaMemcpy(
        d_running_var, running_var.data(), sizeof(float) * inputChannels,
        cudaMemcpyHostToDevice
    ));
}

int BatchNorm2d::getInputSize() {
    return inputSize.first * inputSize.second * inputChannels;
}

int BatchNorm2d::getOutputSize() {
    return inputSize.first * inputSize.second * inputChannels;
}

shape2d BatchNorm2d::getOutputDims() {
    return inputSize;
}

float *BatchNorm2d::forward(const float *d_input) {
    // Compute per-channel batch normalization
    for (int i = 0; i < inputChannels; i++) {
        // Subtract mean from input
        Kernels::vec_scalar_sub<<<gridSize, BLOCK_SIZE>>>(
            d_input + i * inputSize.first * inputSize.second,
            d_output + i * inputSize.first * inputSize.second,
            &d_running_mean[i], inputSize.first * inputSize.second
        );
        CUDA_CHECK(cudaGetLastError());

        // Divide by sqrt(running_var + epsilon)
        Kernels::vec_scale<<<gridSize, BLOCK_SIZE>>>(
            d_output + i * inputSize.first * inputSize.second,
            d_output + i * inputSize.first * inputSize.second,
            &d_running_var[i],
            d_epsilon,
            inputSize.first * inputSize.second
        );
        CUDA_CHECK(cudaGetLastError());

        // Multiply by weights
        Kernels::vec_scalar_mul<<<gridSize, BLOCK_SIZE>>>(
            d_output + i * inputSize.first * inputSize.second,
            d_output + i * inputSize.first * inputSize.second,
            &d_weights[i],
            inputSize.first * inputSize.second
        );
        CUDA_CHECK(cudaGetLastError());

        // Add biases
        Kernels::vec_scalar_add<<<gridSize, BLOCK_SIZE>>>(
            d_output + i * inputSize.first * inputSize.second,
            d_output + i * inputSize.first * inputSize.second,
            &d_biases[i],
            inputSize.first * inputSize.second
        );
        CUDA_CHECK(cudaGetLastError());
    }

    activation->activate(d_output);

    return d_output;
}
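The four kernel launches per channel in forward compute standard inference-time batch normalization using the stored running statistics: y = (x - running_mean[c]) / sqrt(running_var[c] + epsilon) * weight[c] + bias[c], where vec_scale supplies the 1 / sqrt(var + eps) factor.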
src/cuda/layers/concat.cu (new file, 37 lines)
@@ -0,0 +1,37 @@
#include "concat.cuh"
#include "cuda_helper.cuh"

using namespace CUDANet::Layers;

Concat::Concat(const int inputASize, const int inputBSize)
    : inputASize(inputASize), inputBSize(inputBSize) {
    d_output = nullptr;
    CUDA_CHECK(cudaMalloc(
        (void**)&d_output, sizeof(float) * (inputASize + inputBSize)
    ));
}

Concat::~Concat() {
    cudaFree(d_output);
}

float* Concat::forward(const float* d_input_A, const float* d_input_B) {
    CUDA_CHECK(cudaMemcpy(
        d_output, d_input_A, sizeof(float) * inputASize, cudaMemcpyDeviceToDevice
    ));

    CUDA_CHECK(cudaMemcpy(
        d_output + inputASize, d_input_B,
        sizeof(float) * inputBSize, cudaMemcpyDeviceToDevice
    ));

    CUDA_CHECK(cudaDeviceSynchronize());

    return d_output;
}

int Concat::getOutputSize() {
    return inputASize + inputBSize;
}
src/cuda/layers/conv2d.cu (new file, 144 lines)
@@ -0,0 +1,144 @@
#include <iostream>
#include <vector>

#include "activation.cuh"
#include "conv2d.cuh"
#include "convolution.cuh"
#include "cuda_helper.cuh"
#include "layer.cuh"
#include "matmul.cuh"
#include "vector.cuh"

using namespace CUDANet::Layers;

Conv2d::Conv2d(
    shape2d inputSize,
    int inputChannels,
    shape2d kernelSize,
    shape2d stride,
    int numFilters,
    shape2d paddingSize,
    ActivationType activationType
)
    : inputSize(inputSize),
      inputChannels(inputChannels),
      kernelSize(kernelSize),
      stride(stride),
      numFilters(numFilters),
      paddingSize(paddingSize) {
    outputSize = {
        (inputSize.first - kernelSize.first + 2 * paddingSize.first) /
            stride.first + 1,
        (inputSize.second - kernelSize.second + 2 * paddingSize.second) /
            stride.second + 1
    };

    activation =
        new Activation(activationType, outputSize.first * outputSize.second * numFilters);

    d_output = nullptr;
    CUDA_CHECK(cudaMalloc(
        (void**)&d_output, sizeof(float) * outputSize.first * outputSize.second * numFilters
    ));

    weights.resize(kernelSize.first * kernelSize.second * inputChannels * numFilters);
    initializeWeights();

    d_weights = nullptr;
    CUDA_CHECK(cudaMalloc(
        (void**)&d_weights,
        sizeof(float) * kernelSize.first * kernelSize.second * inputChannels * numFilters
    ));

    biases.resize(numFilters);
    initializeBiases();

    d_biases = nullptr;
    CUDA_CHECK(cudaMalloc((void**)&d_biases, sizeof(float) * numFilters));

    toCuda();
}

Conv2d::~Conv2d() {
    cudaFree(d_output);
    cudaFree(d_weights);
    cudaFree(d_biases);
    delete activation;
}

void Conv2d::initializeWeights() {
    std::fill(weights.begin(), weights.end(), 0.0f);
}

void Conv2d::initializeBiases() {
    std::fill(biases.begin(), biases.end(), 0.0f);
}

void Conv2d::setWeights(const float* weights_input) {
    std::copy(weights_input, weights_input + weights.size(), weights.begin());
    toCuda();
}

std::vector<float> Conv2d::getWeights() {
    return weights;
}

void Conv2d::setBiases(const float* biases_input) {
    std::copy(biases_input, biases_input + biases.size(), biases.begin());
    toCuda();
}

std::vector<float> Conv2d::getBiases() {
    return biases;
}

void Conv2d::toCuda() {
    CUDA_CHECK(cudaMemcpy(
        d_weights, weights.data(),
        sizeof(float) * kernelSize.first * kernelSize.second * inputChannels * numFilters,
        cudaMemcpyHostToDevice
    ));

    CUDA_CHECK(cudaMemcpy(
        d_biases, biases.data(), sizeof(float) * numFilters,
        cudaMemcpyHostToDevice
    ));
}

float* Conv2d::forward(const float* d_input) {
    // Convolve
    dim3 block(8, 8, 8);
    dim3 grid(
        (outputSize.first + block.x - 1) / block.x,
        (outputSize.second + block.y - 1) / block.y,
        (numFilters + block.z - 1) / block.z
    );

    CUDANet::Utils::clear(d_output, outputSize.first * outputSize.second * numFilters);

    Kernels::convolution<<<grid, block>>>(
        d_input, d_weights, d_biases, d_output, inputSize, inputChannels,
        paddingSize, kernelSize, stride, numFilters, outputSize
    );
    CUDA_CHECK(cudaGetLastError());

    // Apply activation
    activation->activate(d_output);

    CUDA_CHECK(cudaDeviceSynchronize());

    return d_output;
}

int Conv2d::getOutputSize() {
    return outputSize.first * outputSize.second * numFilters;
}

int Conv2d::getInputSize() {
    return inputSize.first * inputSize.second * inputChannels;
}

shape2d Conv2d::getOutputDims() {
    return outputSize;
}
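Judging from the kernelIndex computation in convolution.cu, the flat buffer passed to setWeights is expected in filter-major order, index = f * K_h * K_w * C + c * K_h * K_w + k * K_w + l, i.e. a (numFilters, inputChannels, kernelH, kernelW) tensor flattened row-major, and setBiases expects one float per filter.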
src/cuda/layers/dense.cu (new file, 119 lines)
@@ -0,0 +1,119 @@
#include <cuda_runtime.h>

#include <cstdio>
#include <cstdlib>
#include <functional>
#include <iostream>

#include "vector.cuh"
#include "activation.cuh"
#include "cuda_helper.cuh"
#include "dense.cuh"
#include "matmul.cuh"

using namespace CUDANet::Layers;

Dense::Dense(
    int inputSize,
    int outputSize,
    ActivationType activationType
)
    : inputSize(inputSize), outputSize(outputSize) {
    // Allocate memory for weights and biases
    weights.resize(outputSize * inputSize);
    biases.resize(outputSize);

    initializeWeights();
    initializeBiases();

    d_output = nullptr;
    CUDA_CHECK(cudaMalloc((void**)&d_output, sizeof(float) * outputSize));

    d_weights = nullptr;
    d_biases  = nullptr;

    // Allocate GPU memory for weights and biases
    CUDA_CHECK(
        cudaMalloc((void**)&d_weights, sizeof(float) * inputSize * outputSize)
    );
    CUDA_CHECK(cudaMalloc((void**)&d_biases, sizeof(float) * outputSize));
    toCuda();

    // Calculate block and grid sizes
    forwardGridSize =
        (std::max(inputSize, outputSize) + BLOCK_SIZE - 1) / BLOCK_SIZE;
    biasGridSize = (outputSize + BLOCK_SIZE - 1) / BLOCK_SIZE;

    activation = new Activation(activationType, outputSize);
}

Dense::~Dense() {
    cudaFree(d_output);
    cudaFree(d_weights);
    cudaFree(d_biases);
    delete activation;
}

void Dense::initializeWeights() {
    std::fill(weights.begin(), weights.end(), 0.0f);
}

void Dense::initializeBiases() {
    std::fill(biases.begin(), biases.end(), 0.0f);
}

float* Dense::forward(const float* d_input) {
    Kernels::mat_vec_mul<<<forwardGridSize, BLOCK_SIZE>>>(
        d_weights, d_input, d_output, inputSize, outputSize
    );
    CUDA_CHECK(cudaGetLastError());

    Kernels::vec_vec_add<<<biasGridSize, BLOCK_SIZE>>>(
        d_biases, d_output, d_output, outputSize
    );
    CUDA_CHECK(cudaGetLastError());

    activation->activate(d_output);
    CUDA_CHECK(cudaDeviceSynchronize());

    return d_output;
}

void Dense::toCuda() {
    CUDA_CHECK(cudaMemcpy(
        d_weights, weights.data(), sizeof(float) * inputSize * outputSize,
        cudaMemcpyHostToDevice
    ));
    CUDA_CHECK(cudaMemcpy(
        d_biases, biases.data(), sizeof(float) * outputSize,
        cudaMemcpyHostToDevice
    ));
}

void Dense::setWeights(const float* weights_input) {
    std::copy(weights_input, weights_input + weights.size(), weights.begin());
    toCuda();
}

std::vector<float> Dense::getWeights() {
    return weights;
}

void Dense::setBiases(const float* biases_input) {
    std::copy(biases_input, biases_input + biases.size(), biases.begin());
    toCuda();
}

std::vector<float> Dense::getBiases() {
    return biases;
}

int Dense::getOutputSize() {
    return outputSize;
}

int Dense::getInputSize() {
    return inputSize;
}
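Dense::forward computes y = activation(W x + b): mat_vec_mul treats the flat weight buffer as a row-major outputSize x inputSize matrix (row r holds the weights of output neuron r), so setWeights expects exactly outputSize * inputSize floats in that order and setBiases expects outputSize floats.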
src/cuda/layers/input.cu (new file, 31 lines)
@@ -0,0 +1,31 @@
#include "cuda_helper.cuh"
#include "input.cuh"

using namespace CUDANet::Layers;

Input::Input(int inputSize) : inputSize(inputSize) {
    d_output = nullptr;
    CUDA_CHECK(cudaMalloc((void**)&d_output, sizeof(float) * inputSize));
}

Input::~Input() {
    cudaFree(d_output);
}

float* Input::forward(const float* input) {
    CUDA_CHECK(cudaMemcpy(
        d_output, input, sizeof(float) * inputSize, cudaMemcpyHostToDevice
    ));
    CUDA_CHECK(cudaDeviceSynchronize());

    return d_output;
}

int Input::getOutputSize() {
    return inputSize;
}

int Input::getInputSize() {
    return inputSize;
}
src/cuda/layers/max_pooling.cu (new file, 75 lines)
@@ -0,0 +1,75 @@
#include "cuda_helper.cuh"
#include "max_pooling.cuh"
#include "pooling.cuh"

using namespace CUDANet::Layers;

MaxPooling2d::MaxPooling2d(
    shape2d inputSize,
    int nChannels,
    shape2d poolingSize,
    shape2d stride,
    shape2d padding,
    ActivationType activationType
)
    : inputSize(inputSize),
      nChannels(nChannels),
      poolingSize(poolingSize),
      stride(stride),
      padding(padding) {
    outputSize = {
        (inputSize.first + 2 * padding.first - poolingSize.first) /
            stride.first + 1,
        (inputSize.second + 2 * padding.second - poolingSize.second) /
            stride.second + 1
    };

    activation = new Activation(
        activationType, outputSize.first * outputSize.second * nChannels
    );

    d_output = nullptr;
    CUDA_CHECK(cudaMalloc(
        (void**)&d_output,
        sizeof(float) * outputSize.first * outputSize.second * nChannels
    ));
}

MaxPooling2d::~MaxPooling2d() {
    cudaFree(d_output);
    delete activation;
}

float* MaxPooling2d::forward(const float* d_input) {
    dim3 block(8, 8, 8);
    dim3 grid(
        (outputSize.first + block.x - 1) / block.x,
        (outputSize.second + block.y - 1) / block.y,
        (nChannels + block.z - 1) / block.z
    );

    Kernels::max_pooling<<<grid, block>>>(
        d_input, d_output, inputSize, outputSize, nChannels, poolingSize,
        stride, padding
    );
    CUDA_CHECK(cudaGetLastError());

    activation->activate(d_output);
    CUDA_CHECK(cudaDeviceSynchronize());

    return d_output;
}

int MaxPooling2d::getOutputSize() {
    return outputSize.first * outputSize.second * nChannels;
}

int MaxPooling2d::getInputSize() {
    return inputSize.first * inputSize.second * nChannels;
}

shape2d MaxPooling2d::getOutputDims() {
    return outputSize;
}
src/cuda/layers/output.cu (new file, 32 lines)
@@ -0,0 +1,32 @@
#include "output.cuh"

#include "cuda_helper.cuh"

using namespace CUDANet::Layers;

Output::Output(int inputSize) : inputSize(inputSize) {
    h_output = (float*) malloc(sizeof(float) * inputSize);
}

Output::~Output() {
    free(h_output);
}

float* Output::forward(const float* input) {
    CUDA_CHECK(cudaMemcpy(
        h_output, input, sizeof(float) * inputSize, cudaMemcpyDeviceToHost
    ));
    CUDA_CHECK(cudaDeviceSynchronize());

    return h_output;
}

int Output::getOutputSize() {
    return inputSize;
}

int Output::getInputSize() {
    return inputSize;
}
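Taken together, Input copies host data onto the device, the other layers pass raw device pointers between forward() calls, and Output copies the final device buffer back to the host. A rough end-to-end sketch (illustrative only; the sizes are made up and the activation enumerator spelling is assumed from its unqualified use in activation.cu):

#include <vector>
#include "input.cuh"
#include "dense.cuh"
#include "output.cuh"

void pipeline_example() {
    CUDANet::Layers::Input  input(784);
    CUDANet::Layers::Dense  fc(784, 10, CUDANet::Layers::SOFTMAX);  // enumerator spelling assumed
    CUDANet::Layers::Output output(10);

    std::vector<float> h_image(784, 0.0f);
    float* d_in     = input.forward(h_image.data());  // host -> device copy
    float* d_logits = fc.forward(d_in);               // stays on the device
    float* h_probs  = output.forward(d_logits);       // device -> host copy
    (void)h_probs;
}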
src/cuda/utils/cuda_helper.cu (new file, 26 lines)
@@ -0,0 +1,26 @@
#include <cuda_runtime.h>

#include <cstdio>
#include <cstdlib>

#include "cuda_helper.cuh"

cudaDeviceProp initializeCUDA() {
    int deviceCount;
    CUDA_CHECK(cudaGetDeviceCount(&deviceCount));

    if (deviceCount == 0) {
        std::fprintf(stderr, "No CUDA devices found. Exiting.\n");
        std::exit(EXIT_FAILURE);
    }

    int device = 0;
    CUDA_CHECK(cudaSetDevice(device));

    cudaDeviceProp deviceProp;
    CUDA_CHECK(cudaGetDeviceProperties(&deviceProp, device));

    std::printf("Using CUDA device %d: %s\n", device, deviceProp.name);

    return deviceProp;
}
src/cuda/utils/vector.cu (new file, 107 lines)
@@ -0,0 +1,107 @@
#include <iostream>
#include <vector>

#include "vector.cuh"
#include "matmul.cuh"
#include "cuda_helper.cuh"

using namespace CUDANet;

void Utils::print_vec(const float* d_vec, const unsigned int length) {
    std::vector<float> h_vec(length);
    CUDA_CHECK(cudaMemcpy(
        h_vec.data(), d_vec, sizeof(float) * length, cudaMemcpyDeviceToHost
    ));

    for (int i = 0; i < length; ++i) {
        std::cout << h_vec[i] << ", ";
    }

    std::cout << std::endl;
}

void Utils::clear(float* d_vec, const unsigned int length) {
    CUDA_CHECK(cudaMemset(d_vec, 0, sizeof(float) * length));
}

void Utils::max(const float* d_vec, float* d_max, const unsigned int length) {
    const int grid_size = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
    Kernels::max_reduce<<<grid_size, BLOCK_SIZE>>>(d_vec, d_max, length);
    CUDA_CHECK(cudaGetLastError());

    int remaining = grid_size;

    while (remaining > 1) {
        int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
        CUDANet::Kernels::max_reduce<<<blocks_needed, BLOCK_SIZE>>>(d_max, d_max, remaining);
        CUDA_CHECK(cudaGetLastError());

        remaining = blocks_needed;
    }
}

void Utils::sum(const float* d_vec, float* d_sum, const unsigned int length) {
    const int gridSize = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;

    CUDANet::Kernels::sum_reduce<<<gridSize, BLOCK_SIZE>>>(
        d_vec, d_sum, length
    );
    CUDA_CHECK(cudaGetLastError());

    int remaining = gridSize;
    while (remaining > 1) {
        int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
        CUDANet::Kernels::sum_reduce<<<blocks_needed, BLOCK_SIZE>>>(d_sum, d_sum, remaining);
        CUDA_CHECK(cudaGetLastError());

        remaining = blocks_needed;
    }
}

void Utils::mean(const float* d_vec, float* d_mean, float *d_length, int length) {
    Utils::sum(d_vec, d_mean, length);

    const int gridSize = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
    Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
        d_mean,
        d_mean,
        d_length,
        length
    );

    CUDA_CHECK(cudaGetLastError());
}

void Utils::var(float* d_vec, float* d_var, float *d_length, const unsigned int length) {
    const int gridSize = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;

    Kernels::vec_vec_mul<<<gridSize, BLOCK_SIZE>>>(
        d_vec,
        d_vec,
        d_var,
        length
    );
    CUDA_CHECK(cudaGetLastError());

    // Sum over all differences
    Utils::sum(
        d_var,
        d_var,
        length
    );

    // Divide by difference sum / length -> variance
    Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
        d_var,
        d_var,
        d_length,
        length
    );
    CUDA_CHECK(cudaGetLastError());
}
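Note on Utils::var: it squares d_vec element-wise, sums the squares, and divides by *d_length, i.e. it computes (1/N) * sum(x_i^2). That equals the variance only when the caller has already subtracted the mean from d_vec, which is what the in-place comments ("Sum over all differences") appear to assume.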