mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-06 01:34:22 +00:00
Set up basic tests with gtest
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include "cuda_helper.h"
|
#include "cuda_helper.h"
|
||||||
|
#include <cuda_runtime.h>
|
||||||
#include <cublas_v2.h>
|
#include <cublas_v2.h>
|
||||||
|
|
||||||
cudaDeviceProp initializeCUDA(cublasHandle_t& cublasHandle) {
|
cudaDeviceProp initializeCUDA(cublasHandle_t& cublasHandle) {
|
||||||
|
|||||||
14
test/CMakeLists.txt
Normal file
14
test/CMakeLists.txt
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
find_package(GTest REQUIRED)
|
||||||
|
include_directories(${GTEST_INCLUDE_DIRS})
|
||||||
|
|
||||||
|
add_executable(test_dense layers/test_dense.cpp)
|
||||||
|
|
||||||
|
add_library(test_utils
|
||||||
|
test_utils/test_cublas_fixture.cpp
|
||||||
|
)
|
||||||
|
|
||||||
|
target_include_directories(test_utils PUBLIC test_utils)
|
||||||
|
|
||||||
|
target_link_libraries(test_dense gtest gtest_main CUDANet test_utils)
|
||||||
|
|
||||||
|
add_test(NAME TestDense COMMAND test_dense)
|
||||||
24
test/layers/test_dense.cpp
Normal file
24
test/layers/test_dense.cpp
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
#include "gtest/gtest.h"
|
||||||
|
#include "dense.h"
|
||||||
|
#include "test_cublas_fixture.h"
|
||||||
|
|
||||||
|
class DenseLayerTest : public CublasTestFixture {
|
||||||
|
protected:
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(DenseLayerTest, Forward) {
|
||||||
|
|
||||||
|
Layers::Dense denseLayer(3, 2, cublasHandle);
|
||||||
|
|
||||||
|
// Create input and output arrays
|
||||||
|
float input[3] = {1.0f, 2.0f, 3.0f};
|
||||||
|
float output[2] = {0.0f, 0.0f};
|
||||||
|
|
||||||
|
// Perform forward pass
|
||||||
|
denseLayer.forward(input, output);
|
||||||
|
|
||||||
|
// Check if the output is a zero vector
|
||||||
|
EXPECT_FLOAT_EQ(output[0], 0.0f);
|
||||||
|
EXPECT_FLOAT_EQ(output[1], 0.0f);
|
||||||
|
}
|
||||||
13
test/test_utils/test_cublas_fixture.cpp
Normal file
13
test/test_utils/test_cublas_fixture.cpp
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
#include "gtest/gtest.h"
|
||||||
|
#include "cublas_v2.h"
|
||||||
|
#include "test_cublas_fixture.h"
|
||||||
|
|
||||||
|
cublasHandle_t CublasTestFixture::cublasHandle;
|
||||||
|
|
||||||
|
void CublasTestFixture::SetUpTestSuite() {
|
||||||
|
cublasCreate(&cublasHandle);
|
||||||
|
}
|
||||||
|
|
||||||
|
void CublasTestFixture::TearDownTestSuite() {
|
||||||
|
cublasDestroy(cublasHandle);
|
||||||
|
}
|
||||||
10
test/test_utils/test_cublas_fixture.h
Normal file
10
test/test_utils/test_cublas_fixture.h
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
#include "gtest/gtest.h"
|
||||||
|
#include "cublas_v2.h"
|
||||||
|
|
||||||
|
class CublasTestFixture : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
static cublasHandle_t cublasHandle;
|
||||||
|
|
||||||
|
static void SetUpTestSuite();
|
||||||
|
static void TearDownTestSuite();
|
||||||
|
};
|
||||||
Reference in New Issue
Block a user