mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-05 17:34:21 +00:00
Implement device vector utils
This commit is contained in:
@@ -103,12 +103,6 @@ __global__ void sum_reduce(
|
||||
const unsigned int len
|
||||
);
|
||||
|
||||
|
||||
__global__ void clear(
|
||||
float* __restrict__ d_vector,
|
||||
const unsigned int len
|
||||
);
|
||||
|
||||
} // namespace CUDANet::Kernels
|
||||
|
||||
#endif // CUDANET_MATMUL_H
|
||||
44
include/utils/vector.cuh
Normal file
44
include/utils/vector.cuh
Normal file
@@ -0,0 +1,44 @@
|
||||
#ifndef CUDANET_VECTOR_H
|
||||
#define CUDANET_VECTOR_H
|
||||
|
||||
namespace CUDANet::Utils {
|
||||
|
||||
|
||||
/**
|
||||
* @brief Utility function that prints a vector
|
||||
*
|
||||
* @param d_vec Pointer to the vector on device
|
||||
* @param length Length of the vector
|
||||
*/
|
||||
void print_vec(float *d_vec, const unsigned int length);
|
||||
|
||||
/**
|
||||
* @brief Utility function that clears a vector
|
||||
*
|
||||
* @param d_vector Pointer to the vector on device
|
||||
* @param len Length of the vector
|
||||
*/
|
||||
void clear(float *d_vector, const unsigned int len);
|
||||
|
||||
|
||||
/**
|
||||
* @brief Utility function that returns the sum of a vector
|
||||
*
|
||||
* @param d_vec Pointer to the vector
|
||||
* @param length Length of the vector
|
||||
*/
|
||||
void sum(float *d_vec, float *d_sum, const unsigned int length);
|
||||
|
||||
|
||||
/**
|
||||
* @brief Utility function that returns the max of a vector
|
||||
*
|
||||
* @param d_vec Pointer to the vector
|
||||
* @param length Length of the vector
|
||||
*/
|
||||
void max(float *d_vec, float *d_max, const unsigned int length);
|
||||
|
||||
|
||||
} // namespace CUDANet::Utils
|
||||
|
||||
#endif // CUDANET_VECTOR_H
|
||||
Reference in New Issue
Block a user