mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-05 17:34:21 +00:00
Move softmax partial kernels to matmul
This commit is contained in:
@@ -29,45 +29,6 @@ __global__ void relu(
|
|||||||
const unsigned int len
|
const unsigned int len
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Softmax activation exponentiation kernel
|
|
||||||
*
|
|
||||||
* @param src Pointer to the source array
|
|
||||||
* @param dst Pointer to the destination array
|
|
||||||
* @param len Length of the arrays
|
|
||||||
*/
|
|
||||||
__global__ void softmax_exp(
|
|
||||||
const float* __restrict__ src,
|
|
||||||
float* __restrict__ dst,
|
|
||||||
const unsigned int len
|
|
||||||
);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief
|
|
||||||
*
|
|
||||||
* @param d_vector Device pointer to vector
|
|
||||||
* @param d_output Device pointer to output vector
|
|
||||||
* @param w Length of the vector
|
|
||||||
*/
|
|
||||||
__global__ void softmax_sum(
|
|
||||||
const float* __restrict__ d_vector,
|
|
||||||
float* __restrict__ d_output
|
|
||||||
);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Softmax activation function kernel
|
|
||||||
*
|
|
||||||
* @param src Pointer to the source array
|
|
||||||
* @param dst Pointer to the destination array
|
|
||||||
* @param len Length of the arrays
|
|
||||||
*/
|
|
||||||
__global__ void softmax_div(
|
|
||||||
const float* __restrict__ src,
|
|
||||||
float* __restrict__ dst,
|
|
||||||
const float* __restrict__ sum,
|
|
||||||
const unsigned int len
|
|
||||||
);
|
|
||||||
|
|
||||||
} // namespace CUDANet::Kernels
|
} // namespace CUDANet::Kernels
|
||||||
|
|
||||||
#endif // CUDANET_ACTIVATION_FUNCTIONS_H
|
#endif // CUDANET_ACTIVATION_FUNCTIONS_H
|
||||||
@@ -35,17 +35,6 @@ __global__ void vec_vec_add(
|
|||||||
const unsigned int w
|
const unsigned int w
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Max reduction kernel
|
|
||||||
*
|
|
||||||
* @param d_vector Device pointer to vector
|
|
||||||
* @param d_output Device pointer to output vector
|
|
||||||
*/
|
|
||||||
__global__ void max_reduce(
|
|
||||||
const float* __restrict__ d_vector,
|
|
||||||
float* __restrict__ d_output
|
|
||||||
);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Add scalar to each element of the vector
|
* @brief Add scalar to each element of the vector
|
||||||
*
|
*
|
||||||
@@ -56,15 +45,68 @@ __global__ void max_reduce(
|
|||||||
* @return __global__
|
* @return __global__
|
||||||
*/
|
*/
|
||||||
__global__ void vec_scalar_sub(
|
__global__ void vec_scalar_sub(
|
||||||
const float* __restrict__ d_vector,
|
const float* __restrict__ d_src,
|
||||||
|
float* __restrict__ d_out,
|
||||||
const float* __restrict__ d_scalar,
|
const float* __restrict__ d_scalar,
|
||||||
float* __restrict__ d_output,
|
const unsigned int len
|
||||||
const unsigned int w
|
|
||||||
);
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Softmax activation function kernel
|
||||||
|
*
|
||||||
|
* @param src Pointer to the source array
|
||||||
|
* @param dst Pointer to the destination array
|
||||||
|
* @param len Length of the arrays
|
||||||
|
*/
|
||||||
|
__global__ void vec_scalar_div(
|
||||||
|
const float* __restrict__ d_src,
|
||||||
|
float* __restrict__ d_out,
|
||||||
|
const float* __restrict__ d_scalar,
|
||||||
|
const unsigned int len
|
||||||
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Softmax activation exponentiation kernel
|
||||||
|
*
|
||||||
|
* @param src Pointer to the source array
|
||||||
|
* @param dst Pointer to the destination array
|
||||||
|
* @param len Length of the arrays
|
||||||
|
*/
|
||||||
|
__global__ void vec_exp(
|
||||||
|
const float* __restrict__ src,
|
||||||
|
float* __restrict__ dst,
|
||||||
|
const unsigned int len
|
||||||
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Max reduction kernel
|
||||||
|
*
|
||||||
|
* @param d_vector Device pointer to vector
|
||||||
|
* @param d_output Device pointer to output vector
|
||||||
|
*/
|
||||||
|
__global__ void max_reduce(
|
||||||
|
const float* __restrict__ d_vector,
|
||||||
|
float* __restrict__ d_output,
|
||||||
|
const unsigned int len
|
||||||
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief
|
||||||
|
*
|
||||||
|
* @param d_vector Device pointer to vector
|
||||||
|
* @param d_output Device pointer to output vector
|
||||||
|
* @param len Length of the vector
|
||||||
|
*/
|
||||||
|
__global__ void sum_reduce(
|
||||||
|
const float* __restrict__ d_vector,
|
||||||
|
float* __restrict__ d_output,
|
||||||
|
const unsigned int len
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
__global__ void clear(
|
__global__ void clear(
|
||||||
float* __restrict__ d_vector,
|
float* __restrict__ d_vector,
|
||||||
const unsigned int w
|
const unsigned int len
|
||||||
);
|
);
|
||||||
|
|
||||||
} // namespace CUDANet::Kernels
|
} // namespace CUDANet::Kernels
|
||||||
|
|||||||
@@ -28,51 +28,3 @@ __global__ void Kernels::relu(
|
|||||||
dst[i] = src[i] < 0.0 ? 0.0 : src[i];
|
dst[i] = src[i] < 0.0 ? 0.0 : src[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void Kernels::softmax_exp(
|
|
||||||
const float* __restrict__ src,
|
|
||||||
float* __restrict__ dst,
|
|
||||||
const unsigned int len
|
|
||||||
) {
|
|
||||||
int stride = gridDim.x * blockDim.x;
|
|
||||||
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
|
||||||
|
|
||||||
for (int i = tid; i < len; i += stride) {
|
|
||||||
dst[i] = expf(src[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
__global__ void Kernels::softmax_sum(
|
|
||||||
const float* __restrict__ d_vector,
|
|
||||||
float* __restrict__ d_output
|
|
||||||
) {
|
|
||||||
__shared__ float partial_sum[BLOCK_SIZE];
|
|
||||||
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
|
||||||
partial_sum[threadIdx.x] = d_vector[i];
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
|
|
||||||
if (threadIdx.x < s) {
|
|
||||||
partial_sum[threadIdx.x] += partial_sum[threadIdx.x + s];
|
|
||||||
}
|
|
||||||
__syncthreads();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (threadIdx.x == 0) {
|
|
||||||
d_output[blockIdx.x] = partial_sum[0];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
__global__ void Kernels::softmax_div(
|
|
||||||
const float* __restrict__ src,
|
|
||||||
float* __restrict__ dst,
|
|
||||||
const float* __restrict__ sum,
|
|
||||||
const unsigned int len
|
|
||||||
) {
|
|
||||||
int stride = gridDim.x * blockDim.x;
|
|
||||||
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
|
||||||
|
|
||||||
for (int i = tid; i < len; i += stride) {
|
|
||||||
dst[i] = src[i] / sum[0];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
using namespace CUDANet;
|
using namespace CUDANet;
|
||||||
|
|
||||||
|
|
||||||
__global__ void Kernels::mat_vec_mul(
|
__global__ void Kernels::mat_vec_mul(
|
||||||
const float* __restrict__ d_matrix,
|
const float* __restrict__ d_matrix,
|
||||||
const float* __restrict__ d_vector,
|
const float* __restrict__ d_vector,
|
||||||
@@ -37,6 +38,7 @@ __global__ void Kernels::mat_vec_mul(
|
|||||||
d_output[tid] = temp;
|
d_output[tid] = temp;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
__global__ void Kernels::vec_vec_add(
|
__global__ void Kernels::vec_vec_add(
|
||||||
const float* __restrict__ d_vector1,
|
const float* __restrict__ d_vector1,
|
||||||
const float* __restrict__ d_vector2,
|
const float* __restrict__ d_vector2,
|
||||||
@@ -50,14 +52,75 @@ __global__ void Kernels::vec_vec_add(
|
|||||||
d_output[tid] = d_vector1[tid] + d_vector2[tid];
|
d_output[tid] = d_vector1[tid] + d_vector2[tid];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
__global__ void Kernels::vec_scalar_sub(
|
||||||
|
const float* __restrict__ d_src,
|
||||||
|
float* __restrict__ d_out,
|
||||||
|
const float* __restrict__ d_scalar,
|
||||||
|
const unsigned int len
|
||||||
|
) {
|
||||||
|
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
|
if (tid >= len) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
d_out[tid] = d_src[tid] - d_scalar[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
__global__ void Kernels::vec_scalar_div(
|
||||||
|
const float* __restrict__ d_src,
|
||||||
|
float* __restrict__ d_out,
|
||||||
|
const float* __restrict__ d_scalar,
|
||||||
|
const unsigned int len
|
||||||
|
) {
|
||||||
|
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
|
if (tid >= len) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
d_out[tid] = d_src[tid] / d_scalar[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
__global__ void Kernels::vec_exp(
|
||||||
|
const float* __restrict__ src,
|
||||||
|
float* __restrict__ dst,
|
||||||
|
const unsigned int len
|
||||||
|
) {
|
||||||
|
int stride = gridDim.x * blockDim.x;
|
||||||
|
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
|
|
||||||
|
for (int i = tid; i < len; i += stride) {
|
||||||
|
dst[i] = expf(src[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
__global__ void Kernels::clear(
|
||||||
|
float* __restrict__ d_vector,
|
||||||
|
const unsigned int w
|
||||||
|
) {
|
||||||
|
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
|
if (tid >= w) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
d_vector[tid] = 0.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
__global__ void Kernels::max_reduce(
|
__global__ void Kernels::max_reduce(
|
||||||
const float* __restrict__ d_vector,
|
const float* __restrict__ d_vector,
|
||||||
float* __restrict__ d_output
|
float* __restrict__ d_output,
|
||||||
|
const unsigned int len
|
||||||
) {
|
) {
|
||||||
__shared__ float shared_max[BLOCK_SIZE];
|
__shared__ float shared_max[BLOCK_SIZE];
|
||||||
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
shared_max[threadIdx.x] = d_vector[i];
|
if (i < len) {
|
||||||
|
shared_max[threadIdx.x] = d_vector[i];
|
||||||
|
} else {
|
||||||
|
shared_max[threadIdx.x] = -INFINITY;
|
||||||
|
}
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
|
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
|
||||||
@@ -72,26 +135,30 @@ __global__ void Kernels::max_reduce(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void Kernels::vec_scalar_sub(
|
__global__ void Kernels::sum_reduce(
|
||||||
const float* __restrict__ d_vector,
|
const float* __restrict__ d_vector,
|
||||||
const float* __restrict__ d_scalar,
|
|
||||||
float* __restrict__ d_output,
|
float* __restrict__ d_output,
|
||||||
const unsigned int w
|
const unsigned int len
|
||||||
) {
|
) {
|
||||||
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
__shared__ float partial_sum[BLOCK_SIZE];
|
||||||
if (tid >= w) {
|
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
return;
|
|
||||||
}
|
|
||||||
d_output[tid] = d_vector[tid] - d_scalar[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
__global__ void Kernels::clear(
|
if (i < len) {
|
||||||
float* __restrict__ d_vector,
|
partial_sum[threadIdx.x] = d_vector[i];
|
||||||
const unsigned int w
|
} else {
|
||||||
) {
|
partial_sum[threadIdx.x] = 0.0f;
|
||||||
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
}
|
||||||
if (tid >= w) {
|
|
||||||
return;
|
__syncthreads();
|
||||||
|
|
||||||
|
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
|
||||||
|
if (threadIdx.x < s) {
|
||||||
|
partial_sum[threadIdx.x] += partial_sum[threadIdx.x + s];
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
d_output[blockIdx.x] = partial_sum[0];
|
||||||
}
|
}
|
||||||
d_vector[tid] = 0.0f;
|
|
||||||
}
|
}
|
||||||
@@ -4,6 +4,7 @@
|
|||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
|
||||||
#include "activation_functions.cuh"
|
#include "activation_functions.cuh"
|
||||||
|
#include "matmul.cuh"
|
||||||
#include "cuda_helper.cuh"
|
#include "cuda_helper.cuh"
|
||||||
|
|
||||||
TEST(ActivationFunctionsTest, SigmoidSanityCheck) {
|
TEST(ActivationFunctionsTest, SigmoidSanityCheck) {
|
||||||
@@ -43,93 +44,24 @@ TEST(ActivationFunctionsTest, SigmoidSanityCheck) {
|
|||||||
|
|
||||||
cudaFree(d_input);
|
cudaFree(d_input);
|
||||||
cudaFree(d_output);
|
cudaFree(d_output);
|
||||||
|
|
||||||
|
cudaDeviceReset();
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ActivationFunctionsTest, SoftmaxExpTest) {
|
// void print_vec(float* d_vec, int length) {
|
||||||
cudaError_t cudaStatus;
|
|
||||||
|
|
||||||
float input[6] = {22.496f, 36.9006f, 30.9904f,
|
// std::vector<float> h_vec(length);
|
||||||
28.4213f, 26.4541f, 31.7887f};
|
// CUDA_CHECK(cudaMemcpy(
|
||||||
|
// h_vec.data(), d_vec, sizeof(float) * length, cudaMemcpyDeviceToHost
|
||||||
|
// ));
|
||||||
|
|
||||||
std::vector<float> expected = {5886928896.0f, 1.06102872080384e+16f,
|
// float sum = 0.0f;
|
||||||
28771323215872.0f, 2204012904448.0f,
|
|
||||||
308226162688.0f, 63922983927808.0f};
|
|
||||||
|
|
||||||
float* d_input;
|
// for (int i = 0; i < length; ++i) {
|
||||||
float* d_output;
|
// std::cout << h_vec[i] << ", ";
|
||||||
|
// sum += h_vec[i];
|
||||||
|
// }
|
||||||
|
|
||||||
cudaStatus = cudaMalloc((void**)&d_input, sizeof(float) * 6);
|
// std::cout << std::endl;
|
||||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
|
||||||
|
|
||||||
cudaStatus = cudaMalloc((void**)&d_output, sizeof(float) * 6);
|
// }
|
||||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
|
||||||
|
|
||||||
cudaStatus =
|
|
||||||
cudaMemcpy(d_input, input, sizeof(float) * 6, cudaMemcpyHostToDevice);
|
|
||||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
|
||||||
|
|
||||||
CUDANet::Kernels::softmax_exp<<<1, 6>>>(d_input, d_output, 6);
|
|
||||||
cudaStatus = cudaDeviceSynchronize();
|
|
||||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
|
||||||
|
|
||||||
std::vector<float> output(6);
|
|
||||||
|
|
||||||
cudaStatus = cudaMemcpy(
|
|
||||||
output.data(), d_output, sizeof(float) * 6, cudaMemcpyDeviceToHost
|
|
||||||
);
|
|
||||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
|
||||||
|
|
||||||
for (int i = 0; i < 6; i++) {
|
|
||||||
EXPECT_NEAR(expected[i], output[i], 1e7);
|
|
||||||
}
|
|
||||||
|
|
||||||
cudaFree(d_input);
|
|
||||||
cudaFree(d_output);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(ActivationFunctionsTest, SoftmaxSumTest) {
|
|
||||||
cudaError_t cudaStatus;
|
|
||||||
|
|
||||||
const int n = 10;
|
|
||||||
std::vector<float> input(n);
|
|
||||||
for (int i = 0; i < n; i++) {
|
|
||||||
input[i] = i;
|
|
||||||
}
|
|
||||||
|
|
||||||
const float expected = n * (n - 1) / 2;
|
|
||||||
|
|
||||||
float* d_input;
|
|
||||||
float* d_sum;
|
|
||||||
|
|
||||||
const int gridSize = (n + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
|
||||||
|
|
||||||
cudaStatus = cudaMalloc((void**)&d_input, sizeof(float) * n);
|
|
||||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
|
||||||
|
|
||||||
cudaStatus = cudaMalloc((void**)&d_sum, sizeof(float) * n);
|
|
||||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
|
||||||
|
|
||||||
cudaStatus =
|
|
||||||
cudaMemcpy(d_input, input.data(), sizeof(float) * n, cudaMemcpyHostToDevice);
|
|
||||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
|
||||||
|
|
||||||
CUDANet::Kernels::softmax_sum<<<gridSize, BLOCK_SIZE>>>(
|
|
||||||
d_input, d_sum
|
|
||||||
);
|
|
||||||
|
|
||||||
CUDANet::Kernels::softmax_sum<<<1, BLOCK_SIZE>>>(
|
|
||||||
d_sum, d_sum
|
|
||||||
);
|
|
||||||
|
|
||||||
CUDANet::Kernels::softmax_sum<<<1, BLOCK_SIZE>>>(
|
|
||||||
d_sum, d_sum
|
|
||||||
);
|
|
||||||
|
|
||||||
std::vector<float> sum(n);
|
|
||||||
cudaStatus = cudaMemcpy(
|
|
||||||
sum.data(), d_sum, sizeof(float) * n, cudaMemcpyDeviceToHost
|
|
||||||
);
|
|
||||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
|
||||||
|
|
||||||
EXPECT_FLOAT_EQ(expected, sum[0]);
|
|
||||||
}
|
|
||||||
@@ -73,33 +73,141 @@ TEST(MatMulTest, MatVecMulTest) {
|
|||||||
TEST(MatMulTest, MaxReduceTest) {
|
TEST(MatMulTest, MaxReduceTest) {
|
||||||
cudaError_t cudaStatus;
|
cudaError_t cudaStatus;
|
||||||
|
|
||||||
std::vector<float> input = {0.643f, 0.912f, 0.723f, 0.587f, 0.155f, 0.932f, 0.391f, 0.279f, 0.846f, 0.788f};
|
const int n = 1 << 16;
|
||||||
|
|
||||||
|
std::vector<float> input(n);
|
||||||
|
for (int i = 0; i < n; i++) {
|
||||||
|
input[i] = i;
|
||||||
|
}
|
||||||
|
|
||||||
float* d_input;
|
float* d_input;
|
||||||
float* d_output;
|
float* d_output;
|
||||||
|
|
||||||
cudaStatus = cudaMalloc((void**)&d_input, sizeof(float) * 10);
|
cudaStatus = cudaMalloc((void**)&d_input, sizeof(float) * n);
|
||||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
EXPECT_EQ(cudaStatus, cudaSuccess);
|
||||||
|
|
||||||
cudaStatus = cudaMalloc((void**)&d_output, sizeof(float));
|
cudaStatus = cudaMalloc((void**)&d_output, sizeof(float));
|
||||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
EXPECT_EQ(cudaStatus, cudaSuccess);
|
||||||
|
|
||||||
cudaStatus = cudaMemcpy(d_input, input.data(), sizeof(float) * 10, cudaMemcpyHostToDevice);
|
cudaStatus = cudaMemcpy(d_input, input.data(), sizeof(float) * n, cudaMemcpyHostToDevice);
|
||||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
EXPECT_EQ(cudaStatus, cudaSuccess);
|
||||||
|
|
||||||
const int grid_size = (10 + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
const int grid_size = (n + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||||
|
|
||||||
CUDANet::Kernels::max_reduce<<<grid_size, BLOCK_SIZE>>>(d_input, d_output);
|
CUDANet::Kernels::max_reduce<<<grid_size, BLOCK_SIZE>>>(d_input, d_output, n);
|
||||||
CUDANet::Kernels::max_reduce<<<1, BLOCK_SIZE>>>(d_output, d_output);
|
|
||||||
|
|
||||||
std::vector<float> output(10);
|
int remaining = grid_size;
|
||||||
|
while (remaining > 1) {
|
||||||
|
int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||||
|
CUDANet::Kernels::max_reduce<<<blocks_needed, BLOCK_SIZE>>>(d_output, d_output, remaining);
|
||||||
|
remaining = blocks_needed;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> output(n);
|
||||||
cudaStatus = cudaMemcpy(output.data(), d_output, sizeof(float), cudaMemcpyDeviceToHost);
|
cudaStatus = cudaMemcpy(output.data(), d_output, sizeof(float), cudaMemcpyDeviceToHost);
|
||||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
EXPECT_EQ(cudaStatus, cudaSuccess);
|
||||||
|
|
||||||
EXPECT_EQ(output[0], 0.932f);
|
EXPECT_EQ(output[0], 65535.0f);
|
||||||
|
|
||||||
cudaFree(d_input);
|
cudaFree(d_input);
|
||||||
cudaFree(d_output);
|
cudaFree(d_output);
|
||||||
|
|
||||||
|
cudaDeviceReset();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(MatMulTest, VecExpTest) {
|
||||||
|
cudaError_t cudaStatus;
|
||||||
|
|
||||||
|
float input[6] = {22.496f, 36.9006f, 30.9904f,
|
||||||
|
28.4213f, 26.4541f, 31.7887f};
|
||||||
|
|
||||||
|
std::vector<float> expected = {5886928896.0f, 1.06102872080384e+16f,
|
||||||
|
28771323215872.0f, 2204012904448.0f,
|
||||||
|
308226162688.0f, 63922983927808.0f};
|
||||||
|
|
||||||
|
float* d_input;
|
||||||
|
float* d_output;
|
||||||
|
|
||||||
|
cudaStatus = cudaMalloc((void**)&d_input, sizeof(float) * 6);
|
||||||
|
EXPECT_EQ(cudaStatus, cudaSuccess);
|
||||||
|
|
||||||
|
cudaStatus = cudaMalloc((void**)&d_output, sizeof(float) * 6);
|
||||||
|
EXPECT_EQ(cudaStatus, cudaSuccess);
|
||||||
|
|
||||||
|
cudaStatus =
|
||||||
|
cudaMemcpy(d_input, input, sizeof(float) * 6, cudaMemcpyHostToDevice);
|
||||||
|
EXPECT_EQ(cudaStatus, cudaSuccess);
|
||||||
|
|
||||||
|
CUDANet::Kernels::vec_exp<<<1, 6>>>(d_input, d_output, 6);
|
||||||
|
cudaStatus = cudaDeviceSynchronize();
|
||||||
|
EXPECT_EQ(cudaStatus, cudaSuccess);
|
||||||
|
|
||||||
|
std::vector<float> output(6);
|
||||||
|
|
||||||
|
cudaStatus = cudaMemcpy(
|
||||||
|
output.data(), d_output, sizeof(float) * 6, cudaMemcpyDeviceToHost
|
||||||
|
);
|
||||||
|
EXPECT_EQ(cudaStatus, cudaSuccess);
|
||||||
|
|
||||||
|
for (int i = 0; i < 6; i++) {
|
||||||
|
EXPECT_NEAR(expected[i], output[i], 1e7);
|
||||||
|
}
|
||||||
|
|
||||||
|
cudaFree(d_input);
|
||||||
|
cudaFree(d_output);
|
||||||
|
|
||||||
|
cudaDeviceReset();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(MatMulTest, SumReduceTest) {
|
||||||
|
cudaError_t cudaStatus;
|
||||||
|
|
||||||
|
const int n = 1 << 16;
|
||||||
|
|
||||||
|
std::vector<float> input(n);
|
||||||
|
for (int i = 0; i < n; i++) {
|
||||||
|
input[i] = 1.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
const float expected = n;
|
||||||
|
|
||||||
|
float* d_input = nullptr;
|
||||||
|
float* d_sum = nullptr;
|
||||||
|
|
||||||
|
const int gridSize = (n + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||||
|
|
||||||
|
cudaStatus = cudaMalloc((void**)&d_input, sizeof(float) * n);
|
||||||
|
EXPECT_EQ(cudaStatus, cudaSuccess);
|
||||||
|
|
||||||
|
cudaStatus = cudaMalloc((void**)&d_sum, sizeof(float) * n);
|
||||||
|
EXPECT_EQ(cudaStatus, cudaSuccess);
|
||||||
|
|
||||||
|
cudaStatus =
|
||||||
|
cudaMemcpy(d_input, input.data(), sizeof(float) * n, cudaMemcpyHostToDevice);
|
||||||
|
EXPECT_EQ(cudaStatus, cudaSuccess);
|
||||||
|
|
||||||
|
CUDANet::Kernels::sum_reduce<<<gridSize, BLOCK_SIZE>>>(
|
||||||
|
d_input, d_sum, n
|
||||||
|
);
|
||||||
|
|
||||||
|
int remaining = gridSize;
|
||||||
|
while (remaining > 1) {
|
||||||
|
std::cout << remaining << std::endl;
|
||||||
|
int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||||
|
CUDANet::Kernels::sum_reduce<<<blocks_needed, BLOCK_SIZE>>>(d_sum, d_sum, remaining);
|
||||||
|
remaining = blocks_needed;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> sum(n);
|
||||||
|
cudaStatus = cudaMemcpy(
|
||||||
|
sum.data(), d_sum, sizeof(float) * n, cudaMemcpyDeviceToHost
|
||||||
|
);
|
||||||
|
EXPECT_EQ(cudaStatus, cudaSuccess);
|
||||||
|
|
||||||
|
EXPECT_FLOAT_EQ(expected, sum[0]);
|
||||||
|
|
||||||
|
cudaFree(d_input);
|
||||||
|
cudaFree(d_sum);
|
||||||
|
|
||||||
cudaDeviceReset();
|
cudaDeviceReset();
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user