From ee1a8cc6e6c5bb1c7045b643febdac04c1e821e6 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Sat, 17 Feb 2024 23:07:26 +0100 Subject: [PATCH] Set up basic tests with gtest --- src/utils/cuda_helper.cpp | 1 + test/CMakeLists.txt | 14 ++++++++++++++ test/layers/test_dense.cpp | 24 ++++++++++++++++++++++++ test/test_utils/test_cublas_fixture.cpp | 13 +++++++++++++ test/test_utils/test_cublas_fixture.h | 10 ++++++++++ 5 files changed, 62 insertions(+) create mode 100644 test/CMakeLists.txt create mode 100644 test/layers/test_dense.cpp create mode 100644 test/test_utils/test_cublas_fixture.cpp create mode 100644 test/test_utils/test_cublas_fixture.h diff --git a/src/utils/cuda_helper.cpp b/src/utils/cuda_helper.cpp index 396fe50..9c6e5b8 100644 --- a/src/utils/cuda_helper.cpp +++ b/src/utils/cuda_helper.cpp @@ -1,6 +1,7 @@ #include #include #include "cuda_helper.h" +#include #include cudaDeviceProp initializeCUDA(cublasHandle_t& cublasHandle) { diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt new file mode 100644 index 0000000..ccdf4b9 --- /dev/null +++ b/test/CMakeLists.txt @@ -0,0 +1,14 @@ +find_package(GTest REQUIRED) +include_directories(${GTEST_INCLUDE_DIRS}) + +add_executable(test_dense layers/test_dense.cpp) + +add_library(test_utils + test_utils/test_cublas_fixture.cpp +) + +target_include_directories(test_utils PUBLIC test_utils) + +target_link_libraries(test_dense gtest gtest_main CUDANet test_utils) + +add_test(NAME TestDense COMMAND test_dense) \ No newline at end of file diff --git a/test/layers/test_dense.cpp b/test/layers/test_dense.cpp new file mode 100644 index 0000000..31b3d14 --- /dev/null +++ b/test/layers/test_dense.cpp @@ -0,0 +1,24 @@ +#include "gtest/gtest.h" +#include "dense.h" +#include "test_cublas_fixture.h" + +class DenseLayerTest : public CublasTestFixture { +protected: +}; + + +TEST_F(DenseLayerTest, Forward) { + + Layers::Dense denseLayer(3, 2, cublasHandle); + + // Create input and output arrays + float input[3] = {1.0f, 2.0f, 3.0f}; + float output[2] = {0.0f, 0.0f}; + + // Perform forward pass + denseLayer.forward(input, output); + + // Check if the output is a zero vector + EXPECT_FLOAT_EQ(output[0], 0.0f); + EXPECT_FLOAT_EQ(output[1], 0.0f); +} diff --git a/test/test_utils/test_cublas_fixture.cpp b/test/test_utils/test_cublas_fixture.cpp new file mode 100644 index 0000000..ad1a604 --- /dev/null +++ b/test/test_utils/test_cublas_fixture.cpp @@ -0,0 +1,13 @@ +#include "gtest/gtest.h" +#include "cublas_v2.h" +#include "test_cublas_fixture.h" + +cublasHandle_t CublasTestFixture::cublasHandle; + +void CublasTestFixture::SetUpTestSuite() { + cublasCreate(&cublasHandle); +} + +void CublasTestFixture::TearDownTestSuite() { + cublasDestroy(cublasHandle); +} diff --git a/test/test_utils/test_cublas_fixture.h b/test/test_utils/test_cublas_fixture.h new file mode 100644 index 0000000..9e927df --- /dev/null +++ b/test/test_utils/test_cublas_fixture.h @@ -0,0 +1,10 @@ +#include "gtest/gtest.h" +#include "cublas_v2.h" + +class CublasTestFixture : public ::testing::Test { +protected: + static cublasHandle_t cublasHandle; + + static void SetUpTestSuite(); + static void TearDownTestSuite(); +}; \ No newline at end of file