Implement device vector utils

This commit is contained in:
2024-04-11 22:22:33 +02:00
parent 710a33bdde
commit 4b9d123e94
6 changed files with 109 additions and 19 deletions

View File

@@ -103,12 +103,6 @@ __global__ void sum_reduce(
const unsigned int len
);
__global__ void clear(
float* __restrict__ d_vector,
const unsigned int len
);
} // namespace CUDANet::Kernels
#endif // CUDANET_MATMUL_H

44
include/utils/vector.cuh Normal file
View File

@@ -0,0 +1,44 @@
#ifndef CUDANET_VECTOR_H
#define CUDANET_VECTOR_H
namespace CUDANet::Utils {
/**
* @brief Utility function that prints a vector
*
* @param d_vec Pointer to the vector on device
* @param length Length of the vector
*/
void print_vec(float *d_vec, const unsigned int length);
/**
* @brief Utility function that clears a vector
*
* @param d_vector Pointer to the vector on device
* @param len Length of the vector
*/
void clear(float *d_vector, const unsigned int len);
/**
* @brief Utility function that returns the sum of a vector
*
* @param d_vec Pointer to the vector
* @param length Length of the vector
*/
void sum(float *d_vec, float *d_sum, const unsigned int length);
/**
* @brief Utility function that returns the max of a vector
*
* @param d_vec Pointer to the vector
* @param length Length of the vector
*/
void max(float *d_vec, float *d_max, const unsigned int length);
} // namespace CUDANet::Utils
#endif // CUDANET_VECTOR_H