mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-06 01:34:22 +00:00
Rework inception block tests
This commit is contained in:
50
examples/inception_v3/tests/test_fixture.hpp
Normal file
50
examples/inception_v3/tests/test_fixture.hpp
Normal file
@@ -0,0 +1,50 @@
|
||||
#ifndef TEST_FIXTURE_H
|
||||
#define TEST_FIXTURE_H
|
||||
|
||||
#include <cmath>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "inception_v3.hpp"
|
||||
|
||||
class InceptionBlockTest : public ::testing::Test {
|
||||
protected:
|
||||
CUDANet::Model *model;
|
||||
|
||||
cudaError_t cudaStatus;
|
||||
|
||||
shape2d inputShape;
|
||||
int inputChannels;
|
||||
|
||||
int outputSize;
|
||||
|
||||
std::vector<float> input;
|
||||
std::vector<float> expected;
|
||||
|
||||
virtual void SetUp() override {
|
||||
model = nullptr;
|
||||
}
|
||||
|
||||
virtual void TearDown() override {
|
||||
// Clean up
|
||||
delete model;
|
||||
}
|
||||
|
||||
void runTest() {
|
||||
EXPECT_EQ(
|
||||
input.size(), inputShape.first * inputShape.second * inputChannels
|
||||
);
|
||||
|
||||
float *output = model->predict(input.data());
|
||||
|
||||
cudaStatus = cudaGetLastError();
|
||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
||||
|
||||
EXPECT_EQ(outputSize, expected.size());
|
||||
|
||||
for (int i = 0; i < outputSize; ++i) {
|
||||
EXPECT_NEAR(expected[i], output[i], 1e-3f);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
Reference in New Issue
Block a user