Implement device vector utils

This commit is contained in:
2024-04-11 22:22:33 +02:00
parent 710a33bdde
commit 4b9d123e94
6 changed files with 109 additions and 19 deletions

View File

@@ -103,12 +103,6 @@ __global__ void sum_reduce(
const unsigned int len const unsigned int len
); );
__global__ void clear(
float* __restrict__ d_vector,
const unsigned int len
);
} // namespace CUDANet::Kernels } // namespace CUDANet::Kernels
#endif // CUDANET_MATMUL_H #endif // CUDANET_MATMUL_H

44
include/utils/vector.cuh Normal file
View File

@@ -0,0 +1,44 @@
#ifndef CUDANET_VECTOR_H
#define CUDANET_VECTOR_H
namespace CUDANet::Utils {
/**
* @brief Utility function that prints a vector
*
* @param d_vec Pointer to the vector on device
* @param length Length of the vector
*/
void print_vec(float *d_vec, const unsigned int length);
/**
* @brief Utility function that clears a vector
*
* @param d_vector Pointer to the vector on device
* @param len Length of the vector
*/
void clear(float *d_vector, const unsigned int len);
/**
* @brief Utility function that returns the sum of a vector
*
* @param d_vec Pointer to the vector
* @param length Length of the vector
*/
void sum(float *d_vec, float *d_sum, const unsigned int length);
/**
* @brief Utility function that returns the max of a vector
*
* @param d_vec Pointer to the vector
* @param length Length of the vector
*/
void max(float *d_vec, float *d_max, const unsigned int length);
} // namespace CUDANet::Utils
#endif // CUDANET_VECTOR_H

View File

@@ -95,18 +95,6 @@ __global__ void Kernels::vec_exp(
} }
__global__ void Kernels::clear(
float* __restrict__ d_vector,
const unsigned int w
) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= w) {
return;
}
d_vector[tid] = 0.0f;
}
__global__ void Kernels::max_reduce( __global__ void Kernels::max_reduce(
const float* __restrict__ d_vector, const float* __restrict__ d_vector,
float* __restrict__ d_output, float* __restrict__ d_output,

View File

@@ -5,6 +5,7 @@
#include <functional> #include <functional>
#include <iostream> #include <iostream>
#include "vector.cuh"
#include "activation.cuh" #include "activation.cuh"
#include "cuda_helper.cuh" #include "cuda_helper.cuh"
#include "dense.cuh" #include "dense.cuh"
@@ -63,6 +64,9 @@ void Dense::initializeBiases() {
} }
float* Dense::forward(const float* d_input) { float* Dense::forward(const float* d_input) {
CUDANet::Utils::clear(d_output, outputSize);
Kernels::mat_vec_mul<<<forwardGridSize, BLOCK_SIZE>>>( Kernels::mat_vec_mul<<<forwardGridSize, BLOCK_SIZE>>>(
d_weights, d_input, d_output, inputSize, outputSize d_weights, d_input, d_output, inputSize, outputSize
); );

58
src/utils/vector.cu Normal file
View File

@@ -0,0 +1,58 @@
#include <iostream>
#include <vector>
#include "vector.cuh"
#include "matmul.cuh"
#include "cuda_helper.cuh"
using namespace CUDANet;
void Utils::print_vec(float* d_vec, const unsigned int length) {
std::vector<float> h_vec(length);
CUDA_CHECK(cudaMemcpy(
h_vec.data(), d_vec, sizeof(float) * length, cudaMemcpyDeviceToHost
));
for (int i = 0; i < length; ++i) {
std::cout << h_vec[i] << ", ";
}
std::cout << std::endl;
}
void Utils::clear(float* d_vec, const unsigned int length) {
CUDA_CHECK(cudaMemset(d_vec, 0, sizeof(float) * length));
}
void Utils::max(float* d_vec, float* d_max, const unsigned int length) {
const int grid_size = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
Kernels::max_reduce<<<grid_size, BLOCK_SIZE>>>(d_vec, d_max, length);
int remaining = grid_size;
while (remaining > 1) {
int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
CUDANet::Kernels::max_reduce<<<blocks_needed, BLOCK_SIZE>>>(d_max, d_max, remaining);
remaining = blocks_needed;
}
}
void Utils::sum(float* d_vec, float* d_sum, const unsigned int length) {
const int gridSize = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
CUDANet::Kernels::sum_reduce<<<gridSize, BLOCK_SIZE>>>(
d_vec, d_sum, length
);
int remaining = gridSize;
while (remaining > 1) {
std::cout << remaining << std::endl;
int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
CUDANet::Kernels::sum_reduce<<<blocks_needed, BLOCK_SIZE>>>(d_sum, d_sum, remaining);
remaining = blocks_needed;
}
}

View File

@@ -4,6 +4,7 @@
#include <vector> #include <vector>
#include "cuda_helper.cuh" #include "cuda_helper.cuh"
#include "vector.cuh"
#include "matmul.cuh" #include "matmul.cuh"
TEST(MatMulTest, MatVecMulTest) { TEST(MatMulTest, MatVecMulTest) {
@@ -45,7 +46,7 @@ TEST(MatMulTest, MatVecMulTest) {
int THREADS_PER_BLOCK = std::max(w, h); int THREADS_PER_BLOCK = std::max(w, h);
int BLOCKS = 1; int BLOCKS = 1;
CUDANet::Kernels::clear<<<BLOCKS, h>>>(d_output, h); CUDANet::Utils::clear(d_output, h);
CUDANet::Kernels::mat_vec_mul<<<BLOCKS, THREADS_PER_BLOCK, sizeof(float) * w>>>(d_matrix, d_vector, d_output, w, h); CUDANet::Kernels::mat_vec_mul<<<BLOCKS, THREADS_PER_BLOCK, sizeof(float) * w>>>(d_matrix, d_vector, d_output, w, h);
cudaStatus = cudaDeviceSynchronize(); cudaStatus = cudaDeviceSynchronize();
@@ -198,6 +199,7 @@ TEST(MatMulTest, SumReduceTest) {
remaining = blocks_needed; remaining = blocks_needed;
} }
std::vector<float> sum(n); std::vector<float> sum(n);
cudaStatus = cudaMemcpy( cudaStatus = cudaMemcpy(
sum.data(), d_sum, sizeof(float) * n, cudaMemcpyDeviceToHost sum.data(), d_sum, sizeof(float) * n, cudaMemcpyDeviceToHost