Restructure cuda backend

This commit is contained in:
2024-09-05 22:23:47 +02:00
parent 65727dfee8
commit f8220f0ec1
19 changed files with 69 additions and 16 deletions

View File

@@ -0,0 +1,30 @@
#include "activation_functions.cuh"
#include "cuda_helper.cuh"
using namespace CUDANet;
__global__ void Kernels::sigmoid(
const float* __restrict__ src,
float* __restrict__ dst,
const unsigned int len
) {
int stride = gridDim.x * blockDim.x;
int tid = blockDim.x * blockIdx.x + threadIdx.x;
for (int i = tid; i < len; i += stride) {
dst[i] = 1.0 / (1.0 + exp(-src[i]));
}
}
__global__ void Kernels::relu(
const float* __restrict__ src,
float* __restrict__ dst,
const unsigned int len
) {
int stride = gridDim.x * blockDim.x;
int tid = blockDim.x * blockIdx.x + threadIdx.x;
for (int i = tid; i < len; i += stride) {
dst[i] = src[i] < 0.0 ? 0.0 : src[i];
}
}

View File

@@ -0,0 +1,60 @@
#include <iostream>
#include "convolution.cuh"
using namespace CUDANet;
__global__ void Kernels::convolution(
const float* __restrict__ d_input,
const float* __restrict__ d_kernel,
const float* __restrict__ d_bias,
float* __restrict__ d_output,
const shape2d inputSize,
const int nChannels,
const shape2d paddingSize,
const shape2d kernelSize,
const shape2d stride,
const int nFilters,
const shape2d outputSize
) {
int j = blockDim.x * blockIdx.x + threadIdx.x;
int i = blockDim.y * blockIdx.y + threadIdx.y;
int f = blockDim.z * blockIdx.z + threadIdx.z;
if (i >= outputSize.first || j >= outputSize.second || f >= nFilters) {
return;
}
float sum = 0.0f;
// Iterate over kernel and input matrix
for (int c = 0; c < nChannels; c++) {
for (int k = 0; k < kernelSize.first; k++) {
for (int l = 0; l < kernelSize.second; l++) {
// if i, j is in the padding region
if (i * stride.first + k < paddingSize.first ||
i * stride.first + k >=
(inputSize.first + paddingSize.first) ||
j * stride.second + l < paddingSize.second ||
j * stride.second + l >=
(inputSize.second + paddingSize.second)) {
continue;
}
int kernelIndex =
f * kernelSize.first * kernelSize.second * nChannels +
c * kernelSize.first * kernelSize.second +
k * kernelSize.second + l;
int inputIndex = c * inputSize.first * inputSize.second +
(i * stride.first + k - paddingSize.first) *
inputSize.second +
(j * stride.second + l - paddingSize.second);
sum += d_kernel[kernelIndex] * d_input[inputIndex];
}
}
}
d_output[f * outputSize.first * outputSize.second + i * outputSize.second + j] =
sum + d_bias[f];
}

View File

@@ -0,0 +1,211 @@
#include "cuda_helper.cuh"
#include "matmul.cuh"
using namespace CUDANet;
__global__ void Kernels::mat_vec_mul(
const float* __restrict__ d_matrix,
const float* __restrict__ d_vector,
float* __restrict__ d_output,
const unsigned int w,
const unsigned int h
) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid < h) {
float temp = 0.0f;
for (unsigned int j = 0; j < w; j++) {
temp += d_matrix[tid * w + j] * d_vector[j];
}
d_output[tid] = temp;
}
}
__global__ void Kernels::vec_vec_add(
const float* __restrict__ d_vector1,
const float* __restrict__ d_vector2,
float* __restrict__ d_output,
const unsigned int w
) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= w) {
return;
}
d_output[tid] = d_vector1[tid] + d_vector2[tid];
}
__global__ void Kernels::vec_vec_sub(
const float* __restrict__ d_vector1,
const float* __restrict__ d_vector2,
float* __restrict__ d_output,
const unsigned int w
) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= w) {
return;
}
d_output[tid] = d_vector1[tid] - d_vector2[tid];
}
__global__ void Kernels::vec_vec_mul(
const float* __restrict__ d_vector1,
const float* __restrict__ d_vector2,
float* __restrict__ d_output,
const unsigned int w
) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= w) {
return;
}
d_output[tid] = d_vector1[tid] * d_vector2[tid];
}
__global__ void Kernels::vec_scalar_sub(
const float* __restrict__ d_src,
float* __restrict__ d_out,
const float* __restrict__ d_scalar,
const unsigned int len
) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= len) {
return;
}
d_out[tid] = d_src[tid] - *d_scalar;
}
__global__ void Kernels::vec_scalar_add(
const float* __restrict__ d_src,
float* __restrict__ d_out,
const float* __restrict__ d_scalar,
const unsigned int len
) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= len) {
return;
}
d_out[tid] = d_src[tid] + *d_scalar;
}
__global__ void Kernels::vec_scalar_div(
const float* __restrict__ d_src,
float* __restrict__ d_out,
const float* __restrict__ d_scalar,
const unsigned int len
) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= len) {
return;
}
d_out[tid] = d_src[tid] / *d_scalar;
}
__global__ void Kernels::vec_scalar_mul(
const float* __restrict__ d_src,
float* __restrict__ d_out,
const float* __restrict__ d_scalar,
const unsigned int len
) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= len) {
return;
}
d_out[tid] = d_src[tid] * *d_scalar;
}
__global__ void Kernels::vec_exp(
const float* __restrict__ src,
float* __restrict__ dst,
const unsigned int len
) {
int stride = gridDim.x * blockDim.x;
int tid = blockDim.x * blockIdx.x + threadIdx.x;
for (int i = tid; i < len; i += stride) {
dst[i] = expf(src[i]);
}
}
__global__ void Kernels::vec_sqrt(
const float* __restrict__ src,
float* __restrict__ dst,
const unsigned int len
) {
int stride = gridDim.x * blockDim.x;
int tid = blockDim.x * blockIdx.x + threadIdx.x;
for (int i = tid; i < len; i += stride) {
dst[i] = sqrtf(src[i]);
}
}
__global__ void Kernels::vec_scale(
const float* __restrict__ src,
float* __restrict__ dst,
const float* __restrict__ scale,
const float* epsilon,
const unsigned int len
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < len) {
float inv_std = rsqrtf(*scale + *epsilon);
dst[idx] = src[idx] * inv_std;
}
}
__global__ void Kernels::max_reduce(
const float* __restrict__ d_vector,
float* __restrict__ d_output,
const unsigned int len
) {
__shared__ float shared_max[BLOCK_SIZE];
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
shared_max[threadIdx.x] = d_vector[i];
} else {
shared_max[threadIdx.x] = -INFINITY;
}
__syncthreads();
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
if (threadIdx.x < s) {
shared_max[threadIdx.x] = fmaxf(shared_max[threadIdx.x], shared_max[threadIdx.x + s]);
}
__syncthreads();
}
if (threadIdx.x == 0) {
d_output[blockIdx.x] = shared_max[0];
}
}
__global__ void Kernels::sum_reduce(
const float* __restrict__ d_vector,
float* __restrict__ d_output,
const unsigned int len
) {
__shared__ float partial_sum[BLOCK_SIZE];
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
partial_sum[threadIdx.x] = d_vector[i];
} else {
partial_sum[threadIdx.x] = 0.0f;
}
__syncthreads();
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
if (threadIdx.x < s) {
partial_sum[threadIdx.x] += partial_sum[threadIdx.x + s];
}
__syncthreads();
}
if (threadIdx.x == 0) {
d_output[blockIdx.x] = partial_sum[0];
}
}

View File

@@ -0,0 +1,85 @@
#include "cuda_helper.cuh"
#include "layer.cuh"
#include "pooling.cuh"
using namespace CUDANet;
__global__ void Kernels::max_pooling(
const float* __restrict__ d_input,
float* __restrict__ d_output,
const shape2d inputSize,
const shape2d outputSize,
const int nChannels,
const shape2d poolingSize,
const shape2d stride,
const shape2d padding
) {
int j = blockDim.x * blockIdx.x + threadIdx.x;
int i = blockDim.y * blockIdx.y + threadIdx.y;
int c = blockDim.z * blockIdx.z + threadIdx.z;
if (i >= outputSize.first || j >= outputSize.second || c >= nChannels) {
return;
}
float max = 0.0f;
for (int k = 0; k < poolingSize.first; k++) {
for (int l = 0; l < poolingSize.second; l++) {
int inputRow = i * stride.first + k - padding.first;
int inputCol = j * stride.second + l - padding.second;
if (inputRow >= 0 && inputRow < inputSize.first && inputCol >= 0 &&
inputCol < inputSize.second) {
int inputIndex = c * inputSize.first * inputSize.second +
inputRow * inputSize.second + inputCol;
if (d_input[inputIndex] > max) {
max = d_input[inputIndex];
}
}
}
}
d_output
[c * outputSize.first * outputSize.second + i * outputSize.second + j] =
max;
}
__global__ void Kernels::avg_pooling(
const float* __restrict__ d_input,
float* __restrict__ d_output,
const shape2d inputSize,
const shape2d outputSize,
const int nChannels,
const shape2d poolingSize,
const shape2d stride,
const shape2d padding
) {
int j = blockDim.x * blockIdx.x + threadIdx.x;
int i = blockDim.y * blockIdx.y + threadIdx.y;
int c = blockDim.z * blockIdx.z + threadIdx.z;
if (i >= outputSize.first || j >= outputSize.second || c >= nChannels) {
return;
}
float sum = 0.0f;
for (int k = 0; k < poolingSize.first; k++) {
for (int l = 0; l < poolingSize.second; l++) {
int inputRow = i * stride.first + k - padding.first;
int inputCol = j * stride.second + l - padding.second;
if (inputRow >= 0 && inputRow < inputSize.first && inputCol >= 0 &&
inputCol < inputSize.second) {
int inputIndex = c * inputSize.first * inputSize.second +
inputRow * inputSize.second + inputCol;
sum += d_input[inputIndex];
}
}
}
d_output
[c * outputSize.first * outputSize.second + i * outputSize.second + j] =
sum / (poolingSize.first * poolingSize.second);
}

View File

@@ -0,0 +1,28 @@
#include "add.hpp"
#include "matmul.cuh"
#include "cuda_helper.cuh"
using namespace CUDANet::Layers;
void Add::initCUDA() {
d_output = nullptr;
CUDA_CHECK(cudaMalloc((void**)&d_output, sizeof(float) * inputSize));
gridSize = (inputSize + BLOCK_SIZE - 1) / BLOCK_SIZE;
}
void Add::delCUDA() {
cudaFree(d_output);
}
float* Add::forwardCUDA(const float* d_inputA, const float* d_inputB) {
Kernels::vec_vec_add<<<gridSize, BLOCK_SIZE>>>(
d_inputA, d_inputB, d_output, inputSize
);
CUDA_CHECK(cudaGetLastError());
CUDA_CHECK(cudaDeviceSynchronize());
return d_output;
}

View File

@@ -0,0 +1,26 @@
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>
#include "cuda_helper.cuh"
cudaDeviceProp initializeCUDA() {
int deviceCount;
CUDA_CHECK(cudaGetDeviceCount(&deviceCount));
if (deviceCount == 0) {
std::fprintf(stderr, "No CUDA devices found. Exiting.\n");
std::exit(EXIT_FAILURE);
}
int device = 0;
CUDA_CHECK(cudaSetDevice(device));
cudaDeviceProp deviceProp;
CUDA_CHECK(cudaGetDeviceProperties(&deviceProp, device));
std::printf("Using CUDA device %d: %s\n", device, deviceProp.name);
return deviceProp;
}

View File

@@ -0,0 +1,107 @@
#include <iostream>
#include <vector>
#include "vector.cuh"
#include "matmul.cuh"
#include "cuda_helper.cuh"
using namespace CUDANet;
void Utils::print_vec(const float* d_vec, const unsigned int length) {
std::vector<float> h_vec(length);
CUDA_CHECK(cudaMemcpy(
h_vec.data(), d_vec, sizeof(float) * length, cudaMemcpyDeviceToHost
));
for (int i = 0; i < length; ++i) {
std::cout << h_vec[i] << ", ";
}
std::cout << std::endl;
}
void Utils::clear(float* d_vec, const unsigned int length) {
CUDA_CHECK(cudaMemset(d_vec, 0, sizeof(float) * length));
}
void Utils::max(const float* d_vec, float* d_max, const unsigned int length) {
const int grid_size = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
Kernels::max_reduce<<<grid_size, BLOCK_SIZE>>>(d_vec, d_max, length);
CUDA_CHECK(cudaGetLastError());
int remaining = grid_size;
while (remaining > 1) {
int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
CUDANet::Kernels::max_reduce<<<blocks_needed, BLOCK_SIZE>>>(d_max, d_max, remaining);
CUDA_CHECK(cudaGetLastError());
remaining = blocks_needed;
}
}
void Utils::sum(const float* d_vec, float* d_sum, const unsigned int length) {
const int gridSize = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
CUDANet::Kernels::sum_reduce<<<gridSize, BLOCK_SIZE>>>(
d_vec, d_sum, length
);
CUDA_CHECK(cudaGetLastError());
int remaining = gridSize;
while (remaining > 1) {
int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
CUDANet::Kernels::sum_reduce<<<blocks_needed, BLOCK_SIZE>>>(d_sum, d_sum, remaining);
CUDA_CHECK(cudaGetLastError());
remaining = blocks_needed;
}
}
void Utils::mean(const float* d_vec, float* d_mean, float *d_length, int length) {
Utils::sum(d_vec, d_mean, length);
const int gridSize = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
d_mean,
d_mean,
d_length,
length
);
CUDA_CHECK(cudaGetLastError());
}
void Utils::var(float* d_vec, float* d_var, float *d_length, const unsigned int length) {
const int gridSize = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
Kernels::vec_vec_mul<<<gridSize, BLOCK_SIZE>>>(
d_vec,
d_vec,
d_var,
length
);
CUDA_CHECK(cudaGetLastError());
// Sum over all differences
Utils::sum(
d_var,
d_var,
length
);
// Divide by difference sum / length -> variance
Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
d_var,
d_var,
d_length,
length
);
CUDA_CHECK(cudaGetLastError());
}