mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-06 01:34:22 +00:00
Add utils vector mean function
This commit is contained in:
@@ -38,6 +38,17 @@ void sum(const float *d_vec, float *d_sum, const unsigned int length);
|
||||
*/
|
||||
void max(const float *d_vec, float *d_max, const unsigned int length);
|
||||
|
||||
|
||||
/**
|
||||
* @brief Compute the mean of the vector
|
||||
*
|
||||
* @param d_vec Device pointer to the vector
|
||||
* @param d_mean Device pointer to the mean
|
||||
* @param d_length Device pointer to the length
|
||||
* @param length Length of the vector
|
||||
*/
|
||||
void mean(const float *d_vec, float *d_mean, float *d_length, int length);
|
||||
|
||||
// /**
|
||||
// * @brief Compute the variance of a vector
|
||||
// *
|
||||
|
||||
@@ -61,14 +61,16 @@ void Utils::sum(const float* d_vec, float* d_sum, const unsigned int length) {
|
||||
}
|
||||
}
|
||||
|
||||
// __device__ float Utils::mean(float* d_vec, const unsigned int length) {
|
||||
// float sum = 0;
|
||||
// for (int i = 0; i < length; ++i) {
|
||||
// sum += d_vec[i];
|
||||
// }
|
||||
void Utils::mean(const float* d_vec, float* d_mean, float *d_length, int length) {
|
||||
Utils::sum(d_vec, d_mean, length);
|
||||
|
||||
// void Utils::var(float* d_vec, float* d_mean, float* d_var, const unsigned int length) {
|
||||
const int gridSize = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
|
||||
d_mean,
|
||||
d_mean,
|
||||
d_length,
|
||||
length
|
||||
);
|
||||
|
||||
// // TODO:
|
||||
|
||||
// }
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
39
test/utils/test_vector.cu
Normal file
39
test/utils/test_vector.cu
Normal file
@@ -0,0 +1,39 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "vector.cuh"
|
||||
|
||||
TEST(VectorTest, TestVectorMean) {
|
||||
|
||||
cudaError_t cudaStatus;
|
||||
float length = 10;
|
||||
|
||||
std::vector<float> input = {0.44371f, 0.20253f, 0.73232f, 0.40378f, 0.93348f, 0.72756f, 0.63388f, 0.5251f, 0.23973f, 0.52233f};
|
||||
|
||||
float* d_vec = nullptr;
|
||||
cudaStatus = cudaMalloc((void **)&d_vec, sizeof(float) * length);
|
||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
||||
|
||||
float* d_mean = nullptr;
|
||||
cudaStatus = cudaMalloc((void **)&d_mean, sizeof(float) * length);
|
||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
||||
|
||||
float* d_length = nullptr;
|
||||
cudaStatus = cudaMalloc((void **)&d_length, sizeof(float));
|
||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
||||
|
||||
cudaStatus = cudaMemcpy(d_vec, input.data(), sizeof(float) * length, cudaMemcpyHostToDevice);
|
||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
||||
|
||||
cudaStatus = cudaMemcpy(d_length, &length, sizeof(float), cudaMemcpyHostToDevice);
|
||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
||||
|
||||
CUDANet::Utils::mean(d_vec, d_mean, d_length, length);
|
||||
|
||||
std::vector<float> mean(length);
|
||||
cudaStatus = cudaMemcpy(mean.data(), d_mean, sizeof(float) * length, cudaMemcpyDeviceToHost);
|
||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
||||
|
||||
float expected_mean = 0.5364f;
|
||||
EXPECT_NEAR(mean[0], expected_mean, 1e-4);
|
||||
|
||||
}
|
||||
10
tools/vector_test.py
Normal file
10
tools/vector_test.py
Normal file
@@ -0,0 +1,10 @@
|
||||
import torch
|
||||
|
||||
def gen_vector_mean_test_result():
|
||||
input = torch.tensor([0.44371, 0.20253, 0.73232, 0.40378, 0.93348, 0.72756, 0.63388, 0.5251, 0.23973, 0.52233])
|
||||
output = torch.mean(input)
|
||||
|
||||
print(output)
|
||||
|
||||
if __name__ == "__main__":
|
||||
gen_vector_mean_test_result()
|
||||
Reference in New Issue
Block a user