Move softmax partial kernels to matmul

2024-04-11 22:01:47 +02:00
parent bf7c961b9e
commit 710a33bdde
6 changed files with 274 additions and 212 deletions

activation_functions.cuh · View File

@@ -29,45 +29,6 @@ __global__ void relu(
     const unsigned int len
 );
 
-/**
- * @brief Softmax activation exponentiation kernel
- *
- * @param src Pointer to the source array
- * @param dst Pointer to the destination array
- * @param len Length of the arrays
- */
-__global__ void softmax_exp(
-    const float* __restrict__ src,
-    float* __restrict__ dst,
-    const unsigned int len
-);
-
-/**
- * @brief
- *
- * @param d_vector Device pointer to vector
- * @param d_output Device pointer to output vector
- * @param w Length of the vector
- */
-__global__ void softmax_sum(
-    const float* __restrict__ d_vector,
-    float* __restrict__ d_output
-);
-
-/**
- * @brief Softmax activation function kernel
- *
- * @param src Pointer to the source array
- * @param dst Pointer to the destination array
- * @param len Length of the arrays
- */
-__global__ void softmax_div(
-    const float* __restrict__ src,
-    float* __restrict__ dst,
-    const float* __restrict__ sum,
-    const unsigned int len
-);
-
 } // namespace CUDANet::Kernels
 
 #endif // CUDANET_ACTIVATION_FUNCTIONS_H

matmul.cuh · View File

@@ -35,17 +35,6 @@ __global__ void vec_vec_add(
     const unsigned int w
 );
 
-/**
- * @brief Max reduction kernel
- *
- * @param d_vector Device pointer to vector
- * @param d_output Device pointer to output vector
- */
-__global__ void max_reduce(
-    const float* __restrict__ d_vector,
-    float* __restrict__ d_output
-);
-
 /**
  * @brief Subtract a scalar from each element of the vector
  *
@@ -56,15 +45,68 @@ __global__ void max_reduce(
  * @return __global__
  */
 __global__ void vec_scalar_sub(
-    const float* __restrict__ d_vector,
-    const float* __restrict__ d_scalar,
-    float* __restrict__ d_output,
-    const unsigned int w
+    const float* __restrict__ d_src,
+    float* __restrict__ d_out,
+    const float* __restrict__ d_scalar,
+    const unsigned int len
 );
 
+/**
+ * @brief Divide each element of the vector by a scalar
+ *
+ * @param d_src Pointer to the source array
+ * @param d_out Pointer to the destination array
+ * @param len Length of the arrays
+ */
+__global__ void vec_scalar_div(
+    const float* __restrict__ d_src,
+    float* __restrict__ d_out,
+    const float* __restrict__ d_scalar,
+    const unsigned int len
+);
+
+/**
+ * @brief Exponentiate each element of the vector
+ *
+ * @param src Pointer to the source array
+ * @param dst Pointer to the destination array
+ * @param len Length of the arrays
+ */
+__global__ void vec_exp(
+    const float* __restrict__ src,
+    float* __restrict__ dst,
+    const unsigned int len
+);
+
+/**
+ * @brief Max reduction kernel
+ *
+ * @param d_vector Device pointer to vector
+ * @param d_output Device pointer to output vector
+ */
+__global__ void max_reduce(
+    const float* __restrict__ d_vector,
+    float* __restrict__ d_output,
+    const unsigned int len
+);
+
+/**
+ * @brief Sum reduction kernel
+ *
+ * @param d_vector Device pointer to vector
+ * @param d_output Device pointer to output vector
+ * @param len Length of the vector
+ */
+__global__ void sum_reduce(
+    const float* __restrict__ d_vector,
+    float* __restrict__ d_output,
+    const unsigned int len
+);
+
 __global__ void clear(
     float* __restrict__ d_vector,
-    const unsigned int w
+    const unsigned int len
 );
 
 } // namespace CUDANet::Kernels
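
The kernels declared above are the generalized pieces of the softmax that previously lived in activation_functions.cuh: subtract the max, exponentiate, sum, divide. For reference, a host-side composition could look like the sketch below; the softmax helper itself, its scratch-buffer layout, and the origin of BLOCK_SIZE are illustrative assumptions, not part of this commit.

#include "matmul.cuh"        // kernel declarations above
#include "cuda_helper.cuh"   // assumed to provide BLOCK_SIZE, as in the tests

// Illustrative composition: softmax(x) = exp(x - max(x)) / sum(exp(x - max(x))).
// d_tmp holds len floats; d_scratch holds one partial per block.
void softmax(const float* d_in, float* d_out, float* d_tmp, float* d_scratch,
             unsigned int len) {
    const unsigned int grid = (len + BLOCK_SIZE - 1) / BLOCK_SIZE;

    // max(x): per-block partials, folded until one value remains in d_scratch[0]
    CUDANet::Kernels::max_reduce<<<grid, BLOCK_SIZE>>>(d_in, d_scratch, len);
    for (unsigned int r = grid; r > 1;) {
        const unsigned int blocks = (r + BLOCK_SIZE - 1) / BLOCK_SIZE;
        CUDANet::Kernels::max_reduce<<<blocks, BLOCK_SIZE>>>(d_scratch, d_scratch, r);
        r = blocks;
    }

    // exp(x - max(x)); subtracting the max keeps expf from overflowing
    CUDANet::Kernels::vec_scalar_sub<<<grid, BLOCK_SIZE>>>(d_in, d_tmp, d_scratch, len);
    CUDANet::Kernels::vec_exp<<<grid, BLOCK_SIZE>>>(d_tmp, d_tmp, len);

    // sum of the exponentials, folded the same way
    CUDANet::Kernels::sum_reduce<<<grid, BLOCK_SIZE>>>(d_tmp, d_scratch, len);
    for (unsigned int r = grid; r > 1;) {
        const unsigned int blocks = (r + BLOCK_SIZE - 1) / BLOCK_SIZE;
        CUDANet::Kernels::sum_reduce<<<blocks, BLOCK_SIZE>>>(d_scratch, d_scratch, r);
        r = blocks;
    }

    // normalize
    CUDANet::Kernels::vec_scalar_div<<<grid, BLOCK_SIZE>>>(d_tmp, d_out, d_scratch, len);
}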

activation_functions.cu · View File

@@ -28,51 +28,3 @@ __global__ void Kernels::relu(
         dst[i] = src[i] < 0.0 ? 0.0 : src[i];
     }
 }
-
-__global__ void Kernels::softmax_exp(
-    const float* __restrict__ src,
-    float* __restrict__ dst,
-    const unsigned int len
-) {
-    int stride = gridDim.x * blockDim.x;
-    int tid = blockDim.x * blockIdx.x + threadIdx.x;
-
-    for (int i = tid; i < len; i += stride) {
-        dst[i] = expf(src[i]);
-    }
-}
-
-__global__ void Kernels::softmax_sum(
-    const float* __restrict__ d_vector,
-    float* __restrict__ d_output
-) {
-    __shared__ float partial_sum[BLOCK_SIZE];
-    int i = blockIdx.x * blockDim.x + threadIdx.x;
-
-    partial_sum[threadIdx.x] = d_vector[i];
-    __syncthreads();
-
-    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
-        if (threadIdx.x < s) {
-            partial_sum[threadIdx.x] += partial_sum[threadIdx.x + s];
-        }
-        __syncthreads();
-    }
-
-    if (threadIdx.x == 0) {
-        d_output[blockIdx.x] = partial_sum[0];
-    }
-}
-
-__global__ void Kernels::softmax_div(
-    const float* __restrict__ src,
-    float* __restrict__ dst,
-    const float* __restrict__ sum,
-    const unsigned int len
-) {
-    int stride = gridDim.x * blockDim.x;
-    int tid = blockDim.x * blockIdx.x + threadIdx.x;
-
-    for (int i = tid; i < len; i += stride) {
-        dst[i] = src[i] / sum[0];
-    }
-}

matmul.cu · View File

@@ -3,6 +3,7 @@
 using namespace CUDANet;
 
+
 __global__ void Kernels::mat_vec_mul(
     const float* __restrict__ d_matrix,
     const float* __restrict__ d_vector,
@@ -37,6 +38,7 @@ __global__ void Kernels::mat_vec_mul(
     d_output[tid] = temp;
 }
 
+
 __global__ void Kernels::vec_vec_add(
     const float* __restrict__ d_vector1,
     const float* __restrict__ d_vector2,
@@ -50,14 +52,75 @@ __global__ void Kernels::vec_vec_add(
     d_output[tid] = d_vector1[tid] + d_vector2[tid];
 }
 
+__global__ void Kernels::vec_scalar_sub(
+    const float* __restrict__ d_src,
+    float* __restrict__ d_out,
+    const float* __restrict__ d_scalar,
+    const unsigned int len
+) {
+    int tid = blockDim.x * blockIdx.x + threadIdx.x;
+    if (tid >= len) {
+        return;
+    }
+    d_out[tid] = d_src[tid] - d_scalar[0];
+}
+
+__global__ void Kernels::vec_scalar_div(
+    const float* __restrict__ d_src,
+    float* __restrict__ d_out,
+    const float* __restrict__ d_scalar,
+    const unsigned int len
+) {
+    int tid = blockDim.x * blockIdx.x + threadIdx.x;
+    if (tid >= len) {
+        return;
+    }
+    d_out[tid] = d_src[tid] / d_scalar[0];
+}
+
+__global__ void Kernels::vec_exp(
+    const float* __restrict__ src,
+    float* __restrict__ dst,
+    const unsigned int len
+) {
+    int stride = gridDim.x * blockDim.x;
+    int tid = blockDim.x * blockIdx.x + threadIdx.x;
+
+    for (int i = tid; i < len; i += stride) {
+        dst[i] = expf(src[i]);
+    }
+}
+
+__global__ void Kernels::clear(
+    float* __restrict__ d_vector,
+    const unsigned int w
+) {
+    int tid = blockDim.x * blockIdx.x + threadIdx.x;
+    if (tid >= w) {
+        return;
+    }
+    d_vector[tid] = 0.0f;
+}
+
 __global__ void Kernels::max_reduce(
     const float* __restrict__ d_vector,
-    float* __restrict__ d_output
+    float* __restrict__ d_output,
+    const unsigned int len
 ) {
     __shared__ float shared_max[BLOCK_SIZE];
     int i = blockIdx.x * blockDim.x + threadIdx.x;
-    shared_max[threadIdx.x] = d_vector[i];
+
+    if (i < len) {
+        shared_max[threadIdx.x] = d_vector[i];
+    } else {
+        shared_max[threadIdx.x] = -INFINITY;
+    }
+
     __syncthreads();
 
     for (int s = blockDim.x / 2; s > 0; s >>= 1) {
@@ -72,26 +135,30 @@ __global__ void Kernels::max_reduce(
     }
 }
 
-__global__ void Kernels::vec_scalar_sub(
+__global__ void Kernels::sum_reduce(
     const float* __restrict__ d_vector,
-    const float* __restrict__ d_scalar,
     float* __restrict__ d_output,
-    const unsigned int w
+    const unsigned int len
 ) {
-    int tid = blockDim.x * blockIdx.x + threadIdx.x;
-    if (tid >= w) {
-        return;
-    }
-    d_output[tid] = d_vector[tid] - d_scalar[0];
-}
-
-__global__ void Kernels::clear(
-    float* __restrict__ d_vector,
-    const unsigned int w
-) {
-    int tid = blockDim.x * blockIdx.x + threadIdx.x;
-    if (tid >= w) {
-        return;
-    }
-    d_vector[tid] = 0.0f;
-}
+    __shared__ float partial_sum[BLOCK_SIZE];
+    int i = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (i < len) {
+        partial_sum[threadIdx.x] = d_vector[i];
+    } else {
+        partial_sum[threadIdx.x] = 0.0f;
+    }
+
+    __syncthreads();
+
+    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+        if (threadIdx.x < s) {
+            partial_sum[threadIdx.x] += partial_sum[threadIdx.x + s];
+        }
+        __syncthreads();
+    }
+
+    if (threadIdx.x == 0) {
+        d_output[blockIdx.x] = partial_sum[0];
+    }
+}
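
The bounds checks added above load out-of-range lanes with the reduction's identity element (-INFINITY for max, 0.0f for sum), so input lengths no longer need to be a multiple of BLOCK_SIZE. A minimal host-side driver that folds the per-block partials down to a single value, mirroring the loop the updated tests use (the reduce_sum helper is illustrative, not part of this commit):

#include "matmul.cuh"        // sum_reduce declaration from this commit
#include "cuda_helper.cuh"   // assumed to provide BLOCK_SIZE

// Illustrative helper: reduces d_in (len floats) to a single sum in d_scratch[0].
// d_scratch must hold at least (len + BLOCK_SIZE - 1) / BLOCK_SIZE floats.
void reduce_sum(const float* d_in, float* d_scratch, unsigned int len) {
    unsigned int remaining = (len + BLOCK_SIZE - 1) / BLOCK_SIZE;
    // First pass: one partial sum per block.
    CUDANet::Kernels::sum_reduce<<<remaining, BLOCK_SIZE>>>(d_in, d_scratch, len);
    // Fold the partials until a single value remains.
    while (remaining > 1) {
        const unsigned int blocks = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
        CUDANet::Kernels::sum_reduce<<<blocks, BLOCK_SIZE>>>(d_scratch, d_scratch, remaining);
        remaining = blocks;
    }
}

Folding in place is only guaranteed race-free once a single block is launched per pass; a stricter version would ping-pong between two scratch buffers.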

test_activation_functions.cu · View File

@@ -4,6 +4,7 @@
 #include <iostream>
 
 #include "activation_functions.cuh"
+#include "matmul.cuh"
 #include "cuda_helper.cuh"
 
 TEST(ActivationFunctionsTest, SigmoidSanityCheck) {
@@ -43,93 +44,24 @@ TEST(ActivationFunctionsTest, SigmoidSanityCheck) {
     cudaFree(d_input);
     cudaFree(d_output);
+    cudaDeviceReset();
 }
 
-TEST(ActivationFunctionsTest, SoftmaxExpTest) {
-    cudaError_t cudaStatus;
-
-    float input[6] = {22.496f, 36.9006f, 30.9904f,
-                      28.4213f, 26.4541f, 31.7887f};
-
-    std::vector<float> expected = {5886928896.0f, 1.06102872080384e+16f,
-                                   28771323215872.0f, 2204012904448.0f,
-                                   308226162688.0f, 63922983927808.0f};
-
-    float* d_input;
-    float* d_output;
-
-    cudaStatus = cudaMalloc((void**)&d_input, sizeof(float) * 6);
-    EXPECT_EQ(cudaStatus, cudaSuccess);
-    cudaStatus = cudaMalloc((void**)&d_output, sizeof(float) * 6);
-    EXPECT_EQ(cudaStatus, cudaSuccess);
-
-    cudaStatus =
-        cudaMemcpy(d_input, input, sizeof(float) * 6, cudaMemcpyHostToDevice);
-    EXPECT_EQ(cudaStatus, cudaSuccess);
-
-    CUDANet::Kernels::softmax_exp<<<1, 6>>>(d_input, d_output, 6);
-    cudaStatus = cudaDeviceSynchronize();
-    EXPECT_EQ(cudaStatus, cudaSuccess);
-
-    std::vector<float> output(6);
-    cudaStatus = cudaMemcpy(
-        output.data(), d_output, sizeof(float) * 6, cudaMemcpyDeviceToHost
-    );
-    EXPECT_EQ(cudaStatus, cudaSuccess);
-
-    for (int i = 0; i < 6; i++) {
-        EXPECT_NEAR(expected[i], output[i], 1e7);
-    }
-
-    cudaFree(d_input);
-    cudaFree(d_output);
-}
-
-TEST(ActivationFunctionsTest, SoftmaxSumTest) {
-    cudaError_t cudaStatus;
-
-    const int n = 10;
-    std::vector<float> input(n);
-    for (int i = 0; i < n; i++) {
-        input[i] = i;
-    }
-    const float expected = n * (n - 1) / 2;
-
-    float* d_input;
-    float* d_sum;
-    const int gridSize = (n + BLOCK_SIZE - 1) / BLOCK_SIZE;
-
-    cudaStatus = cudaMalloc((void**)&d_input, sizeof(float) * n);
-    EXPECT_EQ(cudaStatus, cudaSuccess);
-    cudaStatus = cudaMalloc((void**)&d_sum, sizeof(float) * n);
-    EXPECT_EQ(cudaStatus, cudaSuccess);
-
-    cudaStatus =
-        cudaMemcpy(d_input, input.data(), sizeof(float) * n, cudaMemcpyHostToDevice);
-    EXPECT_EQ(cudaStatus, cudaSuccess);
-
-    CUDANet::Kernels::softmax_sum<<<gridSize, BLOCK_SIZE>>>(
-        d_input, d_sum
-    );
-    CUDANet::Kernels::softmax_sum<<<1, BLOCK_SIZE>>>(
-        d_sum, d_sum
-    );
-    CUDANet::Kernels::softmax_sum<<<1, BLOCK_SIZE>>>(
-        d_sum, d_sum
-    );
-
-    std::vector<float> sum(n);
-    cudaStatus = cudaMemcpy(
-        sum.data(), d_sum, sizeof(float) * n, cudaMemcpyDeviceToHost
-    );
-    EXPECT_EQ(cudaStatus, cudaSuccess);
-
-    EXPECT_FLOAT_EQ(expected, sum[0]);
-}
+// void print_vec(float* d_vec, int length) {
+//     std::vector<float> h_vec(length);
+//     CUDA_CHECK(cudaMemcpy(
+//         h_vec.data(), d_vec, sizeof(float) * length, cudaMemcpyDeviceToHost
+//     ));
+//     float sum = 0.0f;
+//     for (int i = 0; i < length; ++i) {
+//         std::cout << h_vec[i] << ", ";
+//         sum += h_vec[i];
+//     }
+//     std::cout << std::endl;
+// }

test_matmul.cu · View File

@@ -73,33 +73,141 @@ TEST(MatMulTest, MatVecMulTest) {
 TEST(MatMulTest, MaxReduceTest) {
     cudaError_t cudaStatus;
 
-    std::vector<float> input = {0.643f, 0.912f, 0.723f, 0.587f, 0.155f, 0.932f, 0.391f, 0.279f, 0.846f, 0.788f};
+    const int n = 1 << 16;
+    std::vector<float> input(n);
+    for (int i = 0; i < n; i++) {
+        input[i] = i;
+    }
 
     float* d_input;
     float* d_output;
 
-    cudaStatus = cudaMalloc((void**)&d_input, sizeof(float) * 10);
+    cudaStatus = cudaMalloc((void**)&d_input, sizeof(float) * n);
     EXPECT_EQ(cudaStatus, cudaSuccess);
 
-    cudaStatus = cudaMalloc((void**)&d_output, sizeof(float));
+    // Allocate one slot per block: every pass writes grid-many partials.
+    cudaStatus = cudaMalloc((void**)&d_output, sizeof(float) * ((n + BLOCK_SIZE - 1) / BLOCK_SIZE));
     EXPECT_EQ(cudaStatus, cudaSuccess);
 
-    cudaStatus = cudaMemcpy(d_input, input.data(), sizeof(float) * 10, cudaMemcpyHostToDevice);
+    cudaStatus = cudaMemcpy(d_input, input.data(), sizeof(float) * n, cudaMemcpyHostToDevice);
     EXPECT_EQ(cudaStatus, cudaSuccess);
 
-    const int grid_size = (10 + BLOCK_SIZE - 1) / BLOCK_SIZE;
-    CUDANet::Kernels::max_reduce<<<grid_size, BLOCK_SIZE>>>(d_input, d_output);
-    CUDANet::Kernels::max_reduce<<<1, BLOCK_SIZE>>>(d_output, d_output);
+    const int grid_size = (n + BLOCK_SIZE - 1) / BLOCK_SIZE;
+    CUDANet::Kernels::max_reduce<<<grid_size, BLOCK_SIZE>>>(d_input, d_output, n);
 
-    std::vector<float> output(10);
+    int remaining = grid_size;
+    while (remaining > 1) {
+        int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
+        CUDANet::Kernels::max_reduce<<<blocks_needed, BLOCK_SIZE>>>(d_output, d_output, remaining);
+        remaining = blocks_needed;
+    }
+
+    std::vector<float> output(n);
     cudaStatus = cudaMemcpy(output.data(), d_output, sizeof(float), cudaMemcpyDeviceToHost);
     EXPECT_EQ(cudaStatus, cudaSuccess);
 
-    EXPECT_EQ(output[0], 0.932f);
+    EXPECT_EQ(output[0], 65535.0f);
 
     cudaFree(d_input);
     cudaFree(d_output);
 
     cudaDeviceReset();
 }
+
+TEST(MatMulTest, VecExpTest) {
+    cudaError_t cudaStatus;
+
+    float input[6] = {22.496f, 36.9006f, 30.9904f,
+                      28.4213f, 26.4541f, 31.7887f};
+    std::vector<float> expected = {5886928896.0f, 1.06102872080384e+16f,
+                                   28771323215872.0f, 2204012904448.0f,
+                                   308226162688.0f, 63922983927808.0f};
+
+    float* d_input;
+    float* d_output;
+
+    cudaStatus = cudaMalloc((void**)&d_input, sizeof(float) * 6);
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+    cudaStatus = cudaMalloc((void**)&d_output, sizeof(float) * 6);
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    cudaStatus =
+        cudaMemcpy(d_input, input, sizeof(float) * 6, cudaMemcpyHostToDevice);
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    CUDANet::Kernels::vec_exp<<<1, 6>>>(d_input, d_output, 6);
+    cudaStatus = cudaDeviceSynchronize();
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    std::vector<float> output(6);
+    cudaStatus = cudaMemcpy(
+        output.data(), d_output, sizeof(float) * 6, cudaMemcpyDeviceToHost
+    );
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    for (int i = 0; i < 6; i++) {
+        EXPECT_NEAR(expected[i], output[i], 1e7);
+    }
+
+    cudaFree(d_input);
+    cudaFree(d_output);
+
+    cudaDeviceReset();
+}
+
+TEST(MatMulTest, SumReduceTest) {
+    cudaError_t cudaStatus;
+
+    const int n = 1 << 16;
+    std::vector<float> input(n);
+    for (int i = 0; i < n; i++) {
+        input[i] = 1.0f;
+    }
+    const float expected = n;
+
+    float* d_input = nullptr;
+    float* d_sum = nullptr;
+    const int gridSize = (n + BLOCK_SIZE - 1) / BLOCK_SIZE;
+
+    cudaStatus = cudaMalloc((void**)&d_input, sizeof(float) * n);
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+    cudaStatus = cudaMalloc((void**)&d_sum, sizeof(float) * n);
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    cudaStatus =
+        cudaMemcpy(d_input, input.data(), sizeof(float) * n, cudaMemcpyHostToDevice);
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    CUDANet::Kernels::sum_reduce<<<gridSize, BLOCK_SIZE>>>(
+        d_input, d_sum, n
+    );
+
+    int remaining = gridSize;
+    while (remaining > 1) {
+        int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
+        CUDANet::Kernels::sum_reduce<<<blocks_needed, BLOCK_SIZE>>>(d_sum, d_sum, remaining);
+        remaining = blocks_needed;
+    }
+
+    std::vector<float> sum(n);
+    cudaStatus = cudaMemcpy(
+        sum.data(), d_sum, sizeof(float) * n, cudaMemcpyDeviceToHost
+    );
+    EXPECT_EQ(cudaStatus, cudaSuccess);
+
+    EXPECT_FLOAT_EQ(expected, sum[0]);
+
+    cudaFree(d_input);
+    cudaFree(d_sum);
+
+    cudaDeviceReset();
+}