Improve softmax numerical stability

This commit is contained in:
2024-04-08 23:25:46 +02:00
parent e419a93408
commit b49dddf34a
6 changed files with 119 additions and 4 deletions

View File

@@ -49,3 +49,38 @@ __global__ void Kernels::vec_vec_add(
}
d_output[tid] = d_vector1[tid] + d_vector2[tid];
}
__global__ void Kernels::max_reduce(
const float* __restrict__ d_vector,
float* __restrict__ d_output
) {
__shared__ float shared_max[BLOCK_SIZE];
int i = blockIdx.x * blockDim.x + threadIdx.x;
shared_max[threadIdx.x] = d_vector[i];
__syncthreads();
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
if (threadIdx.x < s) {
shared_max[threadIdx.x] = fmaxf(shared_max[threadIdx.x], shared_max[threadIdx.x + s]);
}
__syncthreads();
}
if (threadIdx.x == 0) {
d_output[blockIdx.x] = shared_max[0];
}
}
__global__ void Kernels::vec_scalar_sub(
const float* __restrict__ d_vector,
const float* __restrict__ d_scalar,
float* __restrict__ d_output,
const unsigned int w
) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= w) {
return;
}
d_output[tid] = d_vector[tid] - d_scalar[0];
}

View File

@@ -2,6 +2,10 @@
#include "cuda_helper.cuh"
#include "activation_functions.cuh"
#include "matmul.cuh"
#include <iostream>
#include <vector>
using namespace CUDANet::Layers;
@@ -11,6 +15,9 @@ Activation::Activation(ActivationType activation, const unsigned int length)
if (activationType == SOFTMAX) {
d_softmax_sum = nullptr;
CUDA_CHECK(cudaMalloc((void**)&d_softmax_sum, sizeof(float) * length));
d_max = nullptr;
CUDA_CHECK(cudaMalloc((void**)&d_max, sizeof(float) * length));
}
gridSize = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
@@ -37,6 +44,21 @@ void Activation::activate(float* __restrict__ d_input) {
);
break;
case SOFTMAX:
// Find max value
Kernels::max_reduce<<<gridSize, BLOCK_SIZE>>>(
d_input, d_max
);
Kernels::max_reduce<<<1, BLOCK_SIZE>>>(
d_max, d_max
);
// Subtract max value to improve numerical stability
Kernels::vec_scalar_sub<<<gridSize, BLOCK_SIZE>>>(
d_input, d_max, d_input, length
);
// Compute softmax
Kernels::softmax_exp<<<gridSize, BLOCK_SIZE>>>(
d_input, d_input, length
);