Move softmax partial kernels to matmul

This commit is contained in:
2024-04-11 22:01:47 +02:00
parent bf7c961b9e
commit 710a33bdde
6 changed files with 274 additions and 212 deletions

View File

@@ -28,51 +28,3 @@ __global__ void Kernels::relu(
dst[i] = src[i] < 0.0 ? 0.0 : src[i];
}
}
__global__ void Kernels::softmax_exp(
const float* __restrict__ src,
float* __restrict__ dst,
const unsigned int len
) {
int stride = gridDim.x * blockDim.x;
int tid = blockDim.x * blockIdx.x + threadIdx.x;
for (int i = tid; i < len; i += stride) {
dst[i] = expf(src[i]);
}
}
__global__ void Kernels::softmax_sum(
const float* __restrict__ d_vector,
float* __restrict__ d_output
) {
__shared__ float partial_sum[BLOCK_SIZE];
int i = blockIdx.x * blockDim.x + threadIdx.x;
partial_sum[threadIdx.x] = d_vector[i];
__syncthreads();
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
if (threadIdx.x < s) {
partial_sum[threadIdx.x] += partial_sum[threadIdx.x + s];
}
__syncthreads();
}
if (threadIdx.x == 0) {
d_output[blockIdx.x] = partial_sum[0];
}
}
__global__ void Kernels::softmax_div(
const float* __restrict__ src,
float* __restrict__ dst,
const float* __restrict__ sum,
const unsigned int len
) {
int stride = gridDim.x * blockDim.x;
int tid = blockDim.x * blockIdx.x + threadIdx.x;
for (int i = tid; i < len; i += stride) {
dst[i] = src[i] / sum[0];
}
}

View File

@@ -3,6 +3,7 @@
using namespace CUDANet;
__global__ void Kernels::mat_vec_mul(
const float* __restrict__ d_matrix,
const float* __restrict__ d_vector,
@@ -37,6 +38,7 @@ __global__ void Kernels::mat_vec_mul(
d_output[tid] = temp;
}
__global__ void Kernels::vec_vec_add(
const float* __restrict__ d_vector1,
const float* __restrict__ d_vector2,
@@ -50,14 +52,75 @@ __global__ void Kernels::vec_vec_add(
d_output[tid] = d_vector1[tid] + d_vector2[tid];
}
__global__ void Kernels::vec_scalar_sub(
const float* __restrict__ d_src,
float* __restrict__ d_out,
const float* __restrict__ d_scalar,
const unsigned int len
) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= len) {
return;
}
d_out[tid] = d_src[tid] - d_scalar[0];
}
__global__ void Kernels::vec_scalar_div(
const float* __restrict__ d_src,
float* __restrict__ d_out,
const float* __restrict__ d_scalar,
const unsigned int len
) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= len) {
return;
}
d_out[tid] = d_src[tid] / d_scalar[0];
}
__global__ void Kernels::vec_exp(
const float* __restrict__ src,
float* __restrict__ dst,
const unsigned int len
) {
int stride = gridDim.x * blockDim.x;
int tid = blockDim.x * blockIdx.x + threadIdx.x;
for (int i = tid; i < len; i += stride) {
dst[i] = expf(src[i]);
}
}
__global__ void Kernels::clear(
float* __restrict__ d_vector,
const unsigned int w
) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= w) {
return;
}
d_vector[tid] = 0.0f;
}
__global__ void Kernels::max_reduce(
const float* __restrict__ d_vector,
float* __restrict__ d_output
float* __restrict__ d_output,
const unsigned int len
) {
__shared__ float shared_max[BLOCK_SIZE];
int i = blockIdx.x * blockDim.x + threadIdx.x;
shared_max[threadIdx.x] = d_vector[i];
if (i < len) {
shared_max[threadIdx.x] = d_vector[i];
} else {
shared_max[threadIdx.x] = -INFINITY;
}
__syncthreads();
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
@@ -72,26 +135,30 @@ __global__ void Kernels::max_reduce(
}
}
__global__ void Kernels::vec_scalar_sub(
__global__ void Kernels::sum_reduce(
const float* __restrict__ d_vector,
const float* __restrict__ d_scalar,
float* __restrict__ d_output,
const unsigned int w
const unsigned int len
) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= w) {
return;
}
d_output[tid] = d_vector[tid] - d_scalar[0];
}
__shared__ float partial_sum[BLOCK_SIZE];
int i = blockIdx.x * blockDim.x + threadIdx.x;
__global__ void Kernels::clear(
float* __restrict__ d_vector,
const unsigned int w
) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= w) {
return;
if (i < len) {
partial_sum[threadIdx.x] = d_vector[i];
} else {
partial_sum[threadIdx.x] = 0.0f;
}
__syncthreads();
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
if (threadIdx.x < s) {
partial_sum[threadIdx.x] += partial_sum[threadIdx.x + s];
}
__syncthreads();
}
if (threadIdx.x == 0) {
d_output[blockIdx.x] = partial_sum[0];
}
d_vector[tid] = 0.0f;
}