Mirror of https://github.com/lordmathis/CUDANet.git, synced 2025-11-07 18:24:26 +00:00
Restructure cuda backend
src/backends/cuda/kernels/activation_functions.cu (new file, 30 lines)
@@ -0,0 +1,30 @@
#include "activation_functions.cuh"
#include "cuda_helper.cuh"

using namespace CUDANet;

__global__ void Kernels::sigmoid(
    const float* __restrict__ src,
    float* __restrict__ dst,
    const unsigned int len
) {
    int stride = gridDim.x * blockDim.x;
    int tid = blockDim.x * blockIdx.x + threadIdx.x;

    // Grid-stride loop: each thread handles every `stride`-th element.
    for (int i = tid; i < len; i += stride) {
        dst[i] = 1.0f / (1.0f + expf(-src[i]));
    }
}

__global__ void Kernels::relu(
    const float* __restrict__ src,
    float* __restrict__ dst,
    const unsigned int len
) {
    int stride = gridDim.x * blockDim.x;
    int tid = blockDim.x * blockIdx.x + threadIdx.x;

    for (int i = tid; i < len; i += stride) {
        dst[i] = src[i] < 0.0f ? 0.0f : src[i];
    }
}
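As a usage sketch, the grid-stride kernels above accept any grid size; a minimal host-side example, assuming the BLOCK_SIZE constant and CUDA_CHECK macro from cuda_helper.cuh (the helper name run_sigmoid is hypothetical and not part of this commit):

#include "activation_functions.cuh"
#include "cuda_helper.cuh"

// Hypothetical helper: applies the sigmoid kernel to a device buffer of `len` floats.
// The grid calculation follows the (n + BLOCK_SIZE - 1) / BLOCK_SIZE pattern used by
// the layer code in this commit; the grid-stride loop tolerates any grid size.
void run_sigmoid(const float* d_in, float* d_out, unsigned int len) {
    const int gridSize = (len + BLOCK_SIZE - 1) / BLOCK_SIZE;
    CUDANet::Kernels::sigmoid<<<gridSize, BLOCK_SIZE>>>(d_in, d_out, len);
    CUDA_CHECK(cudaGetLastError());
    CUDA_CHECK(cudaDeviceSynchronize());
}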
src/backends/cuda/kernels/convolution.cu (new file, 60 lines)
@@ -0,0 +1,60 @@
#include <iostream>

#include "convolution.cuh"

using namespace CUDANet;

__global__ void Kernels::convolution(
    const float* __restrict__ d_input,
    const float* __restrict__ d_kernel,
    const float* __restrict__ d_bias,
    float* __restrict__ d_output,
    const shape2d inputSize,
    const int nChannels,
    const shape2d paddingSize,
    const shape2d kernelSize,
    const shape2d stride,
    const int nFilters,
    const shape2d outputSize
) {
    int j = blockDim.x * blockIdx.x + threadIdx.x;
    int i = blockDim.y * blockIdx.y + threadIdx.y;
    int f = blockDim.z * blockIdx.z + threadIdx.z;

    if (i >= outputSize.first || j >= outputSize.second || f >= nFilters) {
        return;
    }

    float sum = 0.0f;

    // Iterate over kernel and input matrix
    for (int c = 0; c < nChannels; c++) {
        for (int k = 0; k < kernelSize.first; k++) {
            for (int l = 0; l < kernelSize.second; l++) {
                // if i, j is in the padding region
                if (i * stride.first + k < paddingSize.first ||
                    i * stride.first + k >=
                        (inputSize.first + paddingSize.first) ||
                    j * stride.second + l < paddingSize.second ||
                    j * stride.second + l >=
                        (inputSize.second + paddingSize.second)) {
                    continue;
                }

                int kernelIndex =
                    f * kernelSize.first * kernelSize.second * nChannels +
                    c * kernelSize.first * kernelSize.second +
                    k * kernelSize.second + l;
                int inputIndex = c * inputSize.first * inputSize.second +
                                 (i * stride.first + k - paddingSize.first) *
                                     inputSize.second +
                                 (j * stride.second + l - paddingSize.second);

                sum += d_kernel[kernelIndex] * d_input[inputIndex];
            }
        }
    }

    d_output[f * outputSize.first * outputSize.second + i * outputSize.second + j] =
        sum + d_bias[f];
}
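A minimal launch sketch for the convolution kernel above, mapping x to output columns, y to output rows, and z to filters. The 8x8x1 block shape and the helper name launch_convolution are assumptions, not taken from this commit; shape2d is the pair-like type used in the kernel signature:

#include "convolution.cuh"
#include "cuda_helper.cuh"

// Hypothetical helper illustrating one way to launch Kernels::convolution.
void launch_convolution(
    const float* d_input, const float* d_kernel, const float* d_bias,
    float* d_output, shape2d inputSize, int nChannels, shape2d paddingSize,
    shape2d kernelSize, shape2d stride, int nFilters, shape2d outputSize
) {
    dim3 block(8, 8, 1);  // assumed block shape
    dim3 grid(
        (outputSize.second + block.x - 1) / block.x,  // output columns (j)
        (outputSize.first + block.y - 1) / block.y,   // output rows (i)
        nFilters                                      // one z slice per filter (f)
    );
    CUDANet::Kernels::convolution<<<grid, block>>>(
        d_input, d_kernel, d_bias, d_output, inputSize, nChannels,
        paddingSize, kernelSize, stride, nFilters, outputSize
    );
    CUDA_CHECK(cudaGetLastError());
}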
src/backends/cuda/kernels/matmul.cu (new file, 211 lines)
@@ -0,0 +1,211 @@
#include "cuda_helper.cuh"
#include "matmul.cuh"

using namespace CUDANet;

__global__ void Kernels::mat_vec_mul(
    const float* __restrict__ d_matrix,
    const float* __restrict__ d_vector,
    float* __restrict__ d_output,
    const unsigned int w,
    const unsigned int h
) {
    int tid = blockDim.x * blockIdx.x + threadIdx.x;

    if (tid < h) {
        float temp = 0.0f;

        for (unsigned int j = 0; j < w; j++) {
            temp += d_matrix[tid * w + j] * d_vector[j];
        }

        d_output[tid] = temp;
    }
}

__global__ void Kernels::vec_vec_add(
    const float* __restrict__ d_vector1,
    const float* __restrict__ d_vector2,
    float* __restrict__ d_output,
    const unsigned int w
) {
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    if (tid >= w) {
        return;
    }
    d_output[tid] = d_vector1[tid] + d_vector2[tid];
}

__global__ void Kernels::vec_vec_sub(
    const float* __restrict__ d_vector1,
    const float* __restrict__ d_vector2,
    float* __restrict__ d_output,
    const unsigned int w
) {
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    if (tid >= w) {
        return;
    }
    d_output[tid] = d_vector1[tid] - d_vector2[tid];
}

__global__ void Kernels::vec_vec_mul(
    const float* __restrict__ d_vector1,
    const float* __restrict__ d_vector2,
    float* __restrict__ d_output,
    const unsigned int w
) {
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    if (tid >= w) {
        return;
    }
    d_output[tid] = d_vector1[tid] * d_vector2[tid];
}

__global__ void Kernels::vec_scalar_sub(
    const float* __restrict__ d_src,
    float* __restrict__ d_out,
    const float* __restrict__ d_scalar,
    const unsigned int len
) {
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    if (tid >= len) {
        return;
    }
    d_out[tid] = d_src[tid] - *d_scalar;
}

__global__ void Kernels::vec_scalar_add(
    const float* __restrict__ d_src,
    float* __restrict__ d_out,
    const float* __restrict__ d_scalar,
    const unsigned int len
) {
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    if (tid >= len) {
        return;
    }
    d_out[tid] = d_src[tid] + *d_scalar;
}

__global__ void Kernels::vec_scalar_div(
    const float* __restrict__ d_src,
    float* __restrict__ d_out,
    const float* __restrict__ d_scalar,
    const unsigned int len
) {
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    if (tid >= len) {
        return;
    }
    d_out[tid] = d_src[tid] / *d_scalar;
}

__global__ void Kernels::vec_scalar_mul(
    const float* __restrict__ d_src,
    float* __restrict__ d_out,
    const float* __restrict__ d_scalar,
    const unsigned int len
) {
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    if (tid >= len) {
        return;
    }
    d_out[tid] = d_src[tid] * *d_scalar;
}

__global__ void Kernels::vec_exp(
    const float* __restrict__ src,
    float* __restrict__ dst,
    const unsigned int len
) {
    int stride = gridDim.x * blockDim.x;
    int tid = blockDim.x * blockIdx.x + threadIdx.x;

    for (int i = tid; i < len; i += stride) {
        dst[i] = expf(src[i]);
    }
}

__global__ void Kernels::vec_sqrt(
    const float* __restrict__ src,
    float* __restrict__ dst,
    const unsigned int len
) {
    int stride = gridDim.x * blockDim.x;
    int tid = blockDim.x * blockIdx.x + threadIdx.x;

    for (int i = tid; i < len; i += stride) {
        dst[i] = sqrtf(src[i]);
    }
}

__global__ void Kernels::vec_scale(
    const float* __restrict__ src,
    float* __restrict__ dst,
    const float* __restrict__ scale,
    const float* epsilon,
    const unsigned int len
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < len) {
        // *scale is expected to hold a variance: dst = src / sqrt(variance + epsilon)
        float inv_std = rsqrtf(*scale + *epsilon);
        dst[idx] = src[idx] * inv_std;
    }
}

__global__ void Kernels::max_reduce(
    const float* __restrict__ d_vector,
    float* __restrict__ d_output,
    const unsigned int len
) {
    __shared__ float shared_max[BLOCK_SIZE];
    int i = blockIdx.x * blockDim.x + threadIdx.x;

    // Load one element per thread; pad out-of-range slots with -inf.
    if (i < len) {
        shared_max[threadIdx.x] = d_vector[i];
    } else {
        shared_max[threadIdx.x] = -INFINITY;
    }

    __syncthreads();

    // Tree reduction in shared memory; halve the active threads each step.
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) {
            shared_max[threadIdx.x] =
                fmaxf(shared_max[threadIdx.x], shared_max[threadIdx.x + s]);
        }
        __syncthreads();
    }

    // One partial maximum per block.
    if (threadIdx.x == 0) {
        d_output[blockIdx.x] = shared_max[0];
    }
}

__global__ void Kernels::sum_reduce(
    const float* __restrict__ d_vector,
    float* __restrict__ d_output,
    const unsigned int len
) {
    __shared__ float partial_sum[BLOCK_SIZE];
    int i = blockIdx.x * blockDim.x + threadIdx.x;

    // Load one element per thread; pad out-of-range slots with zero.
    if (i < len) {
        partial_sum[threadIdx.x] = d_vector[i];
    } else {
        partial_sum[threadIdx.x] = 0.0f;
    }

    __syncthreads();

    // Tree reduction in shared memory; halve the active threads each step.
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) {
            partial_sum[threadIdx.x] += partial_sum[threadIdx.x + s];
        }
        __syncthreads();
    }

    // One partial sum per block.
    if (threadIdx.x == 0) {
        d_output[blockIdx.x] = partial_sum[0];
    }
}
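As a usage sketch, mat_vec_mul and vec_vec_add can be combined into a fully connected forward pass (y = W*x + b). The helper name dense_forward and the final synchronization are assumptions; the grid-size calculation mirrors the Add layer later in this commit:

#include "matmul.cuh"
#include "cuda_helper.cuh"

// Hypothetical sketch: y = W * x + b for an (h x w) row-major weight matrix.
// mat_vec_mul uses one thread per output row; vec_vec_add one thread per element.
void dense_forward(
    const float* d_weights,  // h * w floats, row-major
    const float* d_input,    // w floats
    const float* d_bias,     // h floats
    float* d_output,         // h floats
    unsigned int w, unsigned int h
) {
    const int gridSize = (h + BLOCK_SIZE - 1) / BLOCK_SIZE;
    CUDANet::Kernels::mat_vec_mul<<<gridSize, BLOCK_SIZE>>>(
        d_weights, d_input, d_output, w, h
    );
    CUDA_CHECK(cudaGetLastError());
    CUDANet::Kernels::vec_vec_add<<<gridSize, BLOCK_SIZE>>>(
        d_output, d_bias, d_output, h
    );
    CUDA_CHECK(cudaGetLastError());
    CUDA_CHECK(cudaDeviceSynchronize());
}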
src/backends/cuda/kernels/pooling.cu (new file, 85 lines)
@@ -0,0 +1,85 @@
#include "cuda_helper.cuh"
#include "layer.cuh"
#include "pooling.cuh"

using namespace CUDANet;

__global__ void Kernels::max_pooling(
    const float* __restrict__ d_input,
    float* __restrict__ d_output,
    const shape2d inputSize,
    const shape2d outputSize,
    const int nChannels,
    const shape2d poolingSize,
    const shape2d stride,
    const shape2d padding
) {
    int j = blockDim.x * blockIdx.x + threadIdx.x;
    int i = blockDim.y * blockIdx.y + threadIdx.y;
    int c = blockDim.z * blockIdx.z + threadIdx.z;

    if (i >= outputSize.first || j >= outputSize.second || c >= nChannels) {
        return;
    }

    // Initialise to -inf so all-negative pooling windows are handled correctly.
    float max = -INFINITY;

    for (int k = 0; k < poolingSize.first; k++) {
        for (int l = 0; l < poolingSize.second; l++) {
            int inputRow = i * stride.first + k - padding.first;
            int inputCol = j * stride.second + l - padding.second;

            if (inputRow >= 0 && inputRow < inputSize.first && inputCol >= 0 &&
                inputCol < inputSize.second) {
                int inputIndex = c * inputSize.first * inputSize.second +
                                 inputRow * inputSize.second + inputCol;
                if (d_input[inputIndex] > max) {
                    max = d_input[inputIndex];
                }
            }
        }
    }

    d_output[c * outputSize.first * outputSize.second + i * outputSize.second + j] =
        max;
}

__global__ void Kernels::avg_pooling(
    const float* __restrict__ d_input,
    float* __restrict__ d_output,
    const shape2d inputSize,
    const shape2d outputSize,
    const int nChannels,
    const shape2d poolingSize,
    const shape2d stride,
    const shape2d padding
) {
    int j = blockDim.x * blockIdx.x + threadIdx.x;
    int i = blockDim.y * blockIdx.y + threadIdx.y;
    int c = blockDim.z * blockIdx.z + threadIdx.z;

    if (i >= outputSize.first || j >= outputSize.second || c >= nChannels) {
        return;
    }

    float sum = 0.0f;

    for (int k = 0; k < poolingSize.first; k++) {
        for (int l = 0; l < poolingSize.second; l++) {
            int inputRow = i * stride.first + k - padding.first;
            int inputCol = j * stride.second + l - padding.second;

            if (inputRow >= 0 && inputRow < inputSize.first && inputCol >= 0 &&
                inputCol < inputSize.second) {
                int inputIndex = c * inputSize.first * inputSize.second +
                                 inputRow * inputSize.second + inputCol;
                sum += d_input[inputIndex];
            }
        }
    }

    d_output[c * outputSize.first * outputSize.second + i * outputSize.second + j] =
        sum / (poolingSize.first * poolingSize.second);
}
src/backends/cuda/layers/add.cu (new file, 28 lines)
@@ -0,0 +1,28 @@
#include "add.hpp"
#include "matmul.cuh"
#include "cuda_helper.cuh"

using namespace CUDANet::Layers;

void Add::initCUDA() {
    d_output = nullptr;
    CUDA_CHECK(cudaMalloc((void**)&d_output, sizeof(float) * inputSize));

    gridSize = (inputSize + BLOCK_SIZE - 1) / BLOCK_SIZE;
}

void Add::delCUDA() {
    cudaFree(d_output);
}

float* Add::forwardCUDA(const float* d_inputA, const float* d_inputB) {
    Kernels::vec_vec_add<<<gridSize, BLOCK_SIZE>>>(
        d_inputA, d_inputB, d_output, inputSize
    );
    CUDA_CHECK(cudaGetLastError());
    CUDA_CHECK(cudaDeviceSynchronize());

    return d_output;
}
src/backends/cuda/utils/cuda_helper.cu (new file, 26 lines)
@@ -0,0 +1,26 @@
#include <cuda_runtime.h>

#include <cstdio>
#include <cstdlib>

#include "cuda_helper.cuh"

cudaDeviceProp initializeCUDA() {
    int deviceCount;
    CUDA_CHECK(cudaGetDeviceCount(&deviceCount));

    if (deviceCount == 0) {
        std::fprintf(stderr, "No CUDA devices found. Exiting.\n");
        std::exit(EXIT_FAILURE);
    }

    int device = 0;
    CUDA_CHECK(cudaSetDevice(device));

    cudaDeviceProp deviceProp;
    CUDA_CHECK(cudaGetDeviceProperties(&deviceProp, device));

    std::printf("Using CUDA device %d: %s\n", device, deviceProp.name);

    return deviceProp;
}
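A minimal sketch of how initializeCUDA and the CUDA_CHECK macro from cuda_helper.cuh might be used at program startup; the allocation and its 1024-float size are arbitrary examples, not part of this commit:

#include <cstdio>
#include <cuda_runtime.h>

#include "cuda_helper.cuh"

int main() {
    // Pick device 0 and report it; exits if no CUDA device is available.
    cudaDeviceProp prop = initializeCUDA();
    std::printf("Multiprocessors: %d\n", prop.multiProcessorCount);

    // Example allocation wrapped in the same error-checking macro the backend uses.
    float* d_buf = nullptr;
    CUDA_CHECK(cudaMalloc((void**)&d_buf, sizeof(float) * 1024));
    CUDA_CHECK(cudaFree(d_buf));

    return 0;
}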
src/backends/cuda/utils/vector.cu (new file, 107 lines)
@@ -0,0 +1,107 @@
#include <iostream>
#include <vector>

#include "vector.cuh"
#include "matmul.cuh"
#include "cuda_helper.cuh"

using namespace CUDANet;

void Utils::print_vec(const float* d_vec, const unsigned int length) {
    std::vector<float> h_vec(length);
    CUDA_CHECK(cudaMemcpy(
        h_vec.data(), d_vec, sizeof(float) * length, cudaMemcpyDeviceToHost
    ));

    for (unsigned int i = 0; i < length; ++i) {
        std::cout << h_vec[i] << ", ";
    }

    std::cout << std::endl;
}

void Utils::clear(float* d_vec, const unsigned int length) {
    CUDA_CHECK(cudaMemset(d_vec, 0, sizeof(float) * length));
}

void Utils::max(const float* d_vec, float* d_max, const unsigned int length) {
    const int grid_size = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
    Kernels::max_reduce<<<grid_size, BLOCK_SIZE>>>(d_vec, d_max, length);
    CUDA_CHECK(cudaGetLastError());

    // Keep reducing the per-block partial maxima until a single value remains.
    int remaining = grid_size;

    while (remaining > 1) {
        int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
        CUDANet::Kernels::max_reduce<<<blocks_needed, BLOCK_SIZE>>>(
            d_max, d_max, remaining
        );
        CUDA_CHECK(cudaGetLastError());

        remaining = blocks_needed;
    }
}

void Utils::sum(const float* d_vec, float* d_sum, const unsigned int length) {
    const int gridSize = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;

    CUDANet::Kernels::sum_reduce<<<gridSize, BLOCK_SIZE>>>(
        d_vec, d_sum, length
    );
    CUDA_CHECK(cudaGetLastError());

    // Keep reducing the per-block partial sums until a single value remains.
    int remaining = gridSize;
    while (remaining > 1) {
        int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
        CUDANet::Kernels::sum_reduce<<<blocks_needed, BLOCK_SIZE>>>(
            d_sum, d_sum, remaining
        );
        CUDA_CHECK(cudaGetLastError());

        remaining = blocks_needed;
    }
}

// d_length must point to a device float holding the element count, since
// vec_scalar_div reads its divisor from device memory.
void Utils::mean(const float* d_vec, float* d_mean, float* d_length, int length) {
    Utils::sum(d_vec, d_mean, length);

    const int gridSize = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
    Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
        d_mean, d_mean, d_length, length
    );

    CUDA_CHECK(cudaGetLastError());
}

void Utils::var(float* d_vec, float* d_var, float* d_length, const unsigned int length) {
    const int gridSize = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;

    // Square the (already mean-subtracted) values element-wise.
    Kernels::vec_vec_mul<<<gridSize, BLOCK_SIZE>>>(
        d_vec, d_vec, d_var, length
    );
    CUDA_CHECK(cudaGetLastError());

    // Sum over all squared differences
    Utils::sum(d_var, d_var, length);

    // Divide the sum by length -> variance
    Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
        d_var, d_var, d_length, length
    );
    CUDA_CHECK(cudaGetLastError());
}
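To illustrate the d_length convention used by mean and var (a device pointer holding the element count as a float, read by vec_scalar_div), a hedged usage sketch; the helper name compute_mean and the assumption that d_mean has room for the per-block partial sums written by sum_reduce are not part of this commit:

#include "vector.cuh"
#include "cuda_helper.cuh"

// Hypothetical usage of Utils::mean. d_vec holds n floats; d_mean must be large
// enough for sum_reduce's per-block partial sums; d_length holds (float)n on the device.
void compute_mean(const float* d_vec, float* d_mean, unsigned int n) {
    float h_len = static_cast<float>(n);
    float* d_length = nullptr;
    CUDA_CHECK(cudaMalloc((void**)&d_length, sizeof(float)));
    CUDA_CHECK(cudaMemcpy(d_length, &h_len, sizeof(float), cudaMemcpyHostToDevice));

    CUDANet::Utils::mean(d_vec, d_mean, d_length, n);
    CUDA_CHECK(cudaDeviceSynchronize());

    // The mean is now in d_mean[0].
    CUDA_CHECK(cudaFree(d_length));
}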