diff --git a/examples/inception_v3/inception_v3.hpp b/examples/inception_v3/inception_v3.hpp index ec01d2b..768fb15 100644 --- a/examples/inception_v3/inception_v3.hpp +++ b/examples/inception_v3/inception_v3.hpp @@ -23,7 +23,6 @@ class BasicConv2d : public CUDANet::Module { int getOutputChannels(); - private: int outputChannels; CUDANet::Layers::Conv2d *conv; diff --git a/examples/inception_v3/tests/inception_blocks.py b/examples/inception_v3/tests/inception_blocks.py index fbc5dc3..05d9cdf 100644 --- a/examples/inception_v3/tests/inception_blocks.py +++ b/examples/inception_v3/tests/inception_blocks.py @@ -2,6 +2,7 @@ import sys import torch from torchvision.models.inception import ( + BasicConv2d, InceptionA, InceptionB, InceptionC, @@ -26,6 +27,7 @@ class InceptionBlockModel(torch.nn.Module): x = self.inception_block(x) x = torch.flatten(x) x = self.fc(x) + # x = torch.nn.functional.tanh(x) return x @@ -37,6 +39,11 @@ def init_weights(m: torch.nn.Module): torch.nn.init.uniform_(m.weight, -1) torch.nn.init.uniform_(m.bias, 1) + if isinstance(m, torch.nn.BatchNorm2d): + # Initialize running_mean and running_var + m.running_mean.uniform_(-1, 1) + m.running_var.uniform_(0, 1) # Variance should be positive + @torch.no_grad() def generate_module_test_data(m: torch.nn.Module, name: str): @@ -64,17 +71,20 @@ def generate_module_test_data(m: torch.nn.Module, name: str): if __name__ == "__main__": - m = InceptionA(3, 6) - generate_module_test_data(m, "inception_a") + # m = BasicConv2d(3, 32, kernel_size=3, stride=2, padding=0) + # generate_module_test_data(m, "basic_conv2d") - m = InceptionB(3) - generate_module_test_data(m, "inception_b") + # m = InceptionA(3, 6) + # generate_module_test_data(m, "inception_a") - m = InceptionC(3, 64) - generate_module_test_data(m, "inception_c") + # m = InceptionB(3) + # generate_module_test_data(m, "inception_b") - m = InceptionD(3) - generate_module_test_data(m, "inception_d") + # m = InceptionC(3, 64) + # generate_module_test_data(m, "inception_c") + + # m = InceptionD(3) + # generate_module_test_data(m, "inception_d") m = InceptionE(3) generate_module_test_data(m, "inception_e") diff --git a/examples/inception_v3/tests/test_basic_conv2d.cpp b/examples/inception_v3/tests/test_basic_conv2d.cpp index b91c6fc..311b818 100644 --- a/examples/inception_v3/tests/test_basic_conv2d.cpp +++ b/examples/inception_v3/tests/test_basic_conv2d.cpp @@ -2,234 +2,71 @@ #include -class BasicConv2dTest : public ::testing::Test { - protected: - BasicConv2d *basic_conv2d; - - shape2d inputShape; - int inputChannels; - int outputChannels; - shape2d kernelSize; - shape2d stride; - shape2d padding; - std::string prefix = "test."; - - float *d_input; - float *d_output; - - std::vector input; - std::vector expected; - - std::vector convWeights; - std::vector convBiases; - - std::vector bnWeights; - std::vector bnBiases; - - virtual void SetUp() override { - basic_conv2d = nullptr; - } - - virtual void TearDown() override { - // Clean up - delete basic_conv2d; - } - - void runTest() { - cudaError_t cudaStatus; +#include "test_fixture.hpp" +class BasicConv2dModel : public CUDANet::Model { + public: + BasicConv2dModel( + const shape2d inputShape, + const int inputChannels, + const int outputChannels, + const int outputSize + ) + : CUDANet::Model(inputShape, inputChannels, outputSize) { basic_conv2d = new BasicConv2d( - inputShape, inputChannels, outputChannels, kernelSize, stride, - padding, prefix + inputShape, inputChannels, outputChannels, {3, 3}, {2, 2}, {0, 0}, + "inception_block." ); - - std::pair layerPair = - basic_conv2d->getLayers()[0]; - - ASSERT_EQ(layerPair.first, prefix + "conv"); - - CUDANet::Layers::Conv2d *conv = - dynamic_cast(layerPair.second); - conv->setWeights(convWeights.data()); - conv->setBiases(convBiases.data()); - - ASSERT_EQ(conv->getWeights().size(), convWeights.size()); - ASSERT_EQ(conv->getBiases().size(), convBiases.size()); - - cudaStatus = cudaGetLastError(); - EXPECT_EQ(cudaStatus, cudaSuccess); - - layerPair = basic_conv2d->getLayers()[1]; - ASSERT_EQ(layerPair.first, prefix + "bn"); - - CUDANet::Layers::BatchNorm2d *bn = - dynamic_cast(layerPair.second); - bn->setWeights(bnWeights.data()); - bn->setBiases(bnBiases.data()); - - ASSERT_EQ(bn->getWeights().size(), bnWeights.size()); - ASSERT_EQ(bn->getBiases().size(), bnBiases.size()); - - cudaStatus = cudaGetLastError(); - EXPECT_EQ(cudaStatus, cudaSuccess); - - cudaStatus = - cudaMalloc((void **)&d_input, sizeof(float) * input.size()); - EXPECT_EQ(cudaStatus, cudaSuccess); - - cudaStatus = cudaMemcpy( - d_input, input.data(), sizeof(float) * input.size(), - cudaMemcpyHostToDevice + addLayer("", basic_conv2d); + fc = new CUDANet::Layers::Dense( + basic_conv2d->getOutputSize(), 50, + CUDANet::Layers::ActivationType::NONE ); - EXPECT_EQ(cudaStatus, cudaSuccess); + addLayer("fc", fc); + }; - d_output = basic_conv2d->forward(d_input); - - cudaStatus = cudaGetLastError(); - EXPECT_EQ(cudaStatus, cudaSuccess); - - int outputSize = basic_conv2d->getOutputSize(); - std::vector output(outputSize); - cudaStatus = cudaMemcpy( - output.data(), d_output, sizeof(float) * output.size(), - cudaMemcpyDeviceToHost - ); - EXPECT_EQ(cudaStatus, cudaSuccess); - - for (int i = 0; i < output.size(); ++i) { - EXPECT_NEAR(expected[i], output[i], 1e-5f); - } + float *predict(const float *input) override { + float *d_input = inputLayer->forward(input); + d_input = basic_conv2d->forward(d_input); + d_input = fc->forward(d_input); + return outputLayer->forward(d_input); } + + private: + BasicConv2d *basic_conv2d; + CUDANet::Layers::Dense *fc; }; -TEST_F(BasicConv2dTest, BasicConv2dTest1) { - inputShape = {8, 8}; - inputChannels = 3; - outputChannels = 6; - kernelSize = {3, 3}; - stride = {1, 1}; - padding = {1, 1}; +TEST_F(InceptionBlockTest, BasicConv2dTest) { + inputShape = {4, 4}; + inputChannels = 3; + outputSize = 50; - // 3x3x3x6 - convWeights = { - 0.18365f, 0.08568f, 0.08126f, 0.68022f, 0.41391f, 0.71204f, 0.66917f, - 0.63586f, 0.28914f, 0.43624f, 0.03018f, 0.47986f, 0.71336f, 0.82706f, - 0.587f, 0.58516f, 0.29813f, 0.19312f, 0.42975f, 0.62522f, 0.34256f, - 0.28057f, 0.37367f, 0.54325f, 0.63421f, 0.46445f, 0.56908f, 0.95247f, - 0.73934f, 0.51263f, 0.14464f, 0.0956f, 0.68846f, 0.14675f, 0.75427f, - 0.50547f, 0.37078f, 0.03316f, 0.42855f, 0.94293f, 0.73855f, 0.86475f, - 0.20687f, 0.37793f, 0.77947f, 0.24402f, 0.07547f, 0.22212f, 0.57188f, - 0.5098f, 0.71999f, 0.63828f, 0.53237f, 0.42874f, 0.43621f, 0.87348f, - 0.0073f, 0.07752f, 0.45232f, 0.78307f, 0.74813f, 0.73456f, 0.0378f, - 0.78518f, 0.6989f, 0.50484f, 0.74265f, 0.39178f, 0.91015f, 0.11684f, - 0.11499f, 0.10394f, 0.30637f, 0.86116f, 0.63743f, 0.64142f, 0.97882f, - 0.30948f, 0.32144f, 0.76108f, 0.81794f, 0.50111f, 0.82209f, 0.49028f, - 0.79417f, 0.3257f, 0.32221f, 0.4007f, 0.86371f, 0.2271f, 0.9414f, - 0.66233f, 0.60802f, 0.65701f, 0.41021f, 0.1135f, 0.21892f, 0.93389f, - 0.65786f, 0.26068f, 0.59535f, 0.15048f, 0.48185f, 0.91072f, 0.18252f, - 0.64154f, 0.89179f, 0.54726f, 0.60756f, 0.31149f, 0.30717f, 0.79877f, - 0.71727f, 0.12418f, 0.48471f, 0.46097f, 0.66898f, 0.35467f, 0.38027f, - 0.16989f, 0.88578f, 0.84377f, 0.26529f, 0.26057f, 0.30256f, 0.84876f, - 0.8849f, 0.08982f, 0.88191f, 0.1944f, 0.42052f, 0.62898f, 0.692f, - 0.51155f, 0.99903f, 0.56947f, 0.73144f, 0.88091f, 0.28472f, 0.98895f, - 0.41364f, 0.1927f, 0.07227f, 0.421f, 0.85347f, 0.19329f, 0.07098f, - 0.19418f, 0.06585f, 0.49083f, 0.85071f, 0.96747f, 0.45057f, 0.54361f, - 0.49552f, 0.23454f, 0.97412f, 0.26663f, 0.09274f, 0.1662f, 0.04784f, - 0.76303f - }; - convBiases.resize(outputChannels, 0.0f); + int outputChannels = 32; - bnWeights = {0.69298f, 0.27049f, 0.85854f, 0.52973f, 0.29644f, 0.68932f}; - bnBiases = {0.74976f, 0.42745f, 0.22132f, 0.21262f, 0.03726f, 0.9719f}; + model = new BasicConv2dModel( + inputShape, inputChannels, outputChannels, outputSize + ); + model->loadWeights("../tests/resources/basic_conv2d.bin"); - input = { - 0.75539f, 0.17641f, 0.8331f, 0.80627f, 0.51712f, 0.87756f, 0.97027f, - 0.21354f, 0.28498f, 0.05118f, 0.37124f, 0.40528f, 0.13661f, 0.08692f, - 0.73809f, 0.57278f, 0.73534f, 0.31338f, 0.15362f, 0.80245f, 0.49524f, - 0.81208f, 0.24074f, 0.42534f, 0.62236f, 0.75915f, 0.06382f, 0.66723f, - 0.13448f, 0.96896f, 0.87197f, 0.67366f, 0.67885f, 0.49345f, 0.08446f, - 0.94116f, 0.8659f, 0.22848f, 0.53262f, 0.51307f, 0.89661f, 0.72223f, - 0.90541f, 0.47353f, 0.85476f, 0.04177f, 0.04039f, 0.7917f, 0.56188f, - 0.53777f, 0.91714f, 0.84847f, 0.16995f, 0.59803f, 0.05454f, 0.00365f, - 0.01429f, 0.42586f, 0.31519f, 0.222f, 0.9149f, 0.51885f, 0.82969f, - 0.42778f, 0.82913f, 0.01303f, 0.92699f, 0.09225f, 0.00284f, 0.75769f, - 0.74072f, 0.59012f, 0.40777f, 0.0469f, 0.08751f, 0.23163f, 0.51327f, - 0.67095f, 0.31971f, 0.97841f, 0.82292f, 0.58917f, 0.31565f, 0.4728f, - 0.41885f, 0.36524f, 0.28194f, 0.70945f, 0.36008f, 0.23199f, 0.71093f, - 0.33364f, 0.34199f, 0.42114f, 0.40026f, 0.77819f, 0.79858f, 0.93793f, - 0.45238f, 0.97922f, 0.73814f, 0.11831f, 0.08414f, 0.56552f, 0.99841f, - 0.53862f, 0.71138f, 0.42274f, 0.48724f, 0.48201f, 0.5361f, 0.97138f, - 0.27607f, 0.33018f, 0.07456f, 0.77788f, 0.58824f, 0.77027f, 0.3938f, - 0.28081f, 0.14074f, 0.06907f, 0.75419f, 0.11888f, 0.35715f, 0.34481f, - 0.05669f, 0.21063f, 0.8664f, 0.00087f, 0.88281f, 0.55202f, 0.68655f, - 0.96262f, 0.53907f, 0.9227f, 0.74055f, 0.84487f, 0.22792f, 0.83233f, - 0.42938f, 0.39054f, 0.59604f, 0.4141f, 0.25982f, 0.9311f, 0.35475f, - 0.71432f, 0.29186f, 0.16604f, 0.90708f, 0.00171f, 0.11541f, 0.35719f, - 0.9221f, 0.18793f, 0.90198f, 0.29281f, 0.72144f, 0.54645f, 0.71165f, - 0.59584f, 0.24041f, 0.60954f, 0.64945f, 0.8122f, 0.34145f, 0.92178f, - 0.99894f, 0.25076f, 0.45067f, 0.71997f, 0.09573f, 0.57334f, 0.63273f, - 0.49469f, 0.72747f, 0.33449f, 0.13755f, 0.49458f, 0.50319f, 0.91328f, - 0.57269f, 0.21927f, 0.36831f, 0.88708f, 0.62277f, 0.08318f, 0.01425f, - 0.17998f, 0.34614f, 0.82303f - }; - - expected = { - 0.0f, 0.49814f, 0.22097f, 0.3619f, 0.46957f, 0.69706f, 1.06759f, - 0.25578f, 0.0f, 0.91978f, 0.53499f, 0.78382f, 1.13748f, 1.27999f, - 1.39561f, 0.59403f, 0.1681f, 1.1653f, 0.9397f, 0.99945f, 1.09875f, - 1.11738f, 1.48957f, 0.39551f, 0.17473f, 1.36075f, 1.38633f, 1.10036f, - 1.66809f, 1.24004f, 1.51673f, 0.35859f, 0.50363f, 1.90002f, 1.76062f, - 1.77264f, 1.653f, 0.98297f, 0.97645f, 0.36179f, 0.65388f, 1.82326f, - 1.62819f, 1.53234f, 1.52987f, 1.1909f, 1.19085f, 0.0f, 0.0f, - 1.00418f, 0.9884f, 1.06528f, 1.10918f, 0.95965f, 1.01066f, 0.0f, - 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.06699f, 0.0f, - 0.0f, 0.0f, 0.31227f, 0.1577f, 0.24142f, 0.29244f, 0.35219f, - 0.55728f, 0.09206f, 0.18279f, 0.52608f, 0.43298f, 0.57281f, 0.64957f, - 0.67697f, 0.79076f, 0.25769f, 0.17322f, 0.45144f, 0.50649f, 0.44384f, - 0.45046f, 0.52827f, 0.65169f, 0.26233f, 0.33391f, 0.54569f, 0.61824f, - 0.71162f, 0.72201f, 0.59606f, 0.69006f, 0.17808f, 0.53409f, 0.84795f, - 0.81671f, 0.72767f, 0.70439f, 0.49824f, 0.77586f, 0.28972f, 0.41066f, - 0.78739f, 0.74518f, 0.69849f, 0.72851f, 0.58154f, 0.59843f, 0.0988f, - 0.12992f, 0.69539f, 0.58411f, 0.53047f, 0.67763f, 0.45745f, 0.42961f, - 0.02356f, 0.0f, 0.1524f, 0.17941f, 0.20621f, 0.07853f, 0.0f, - 0.01425f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, - 0.0f, 0.0f, 0.0f, 0.0f, 0.53197f, 0.23141f, 0.65858f, - 0.51061f, 1.18983f, 1.88715f, 0.0f, 0.0f, 0.48249f, 0.27706f, - 0.4758f, 0.37868f, 0.19115f, 1.3417f, 0.0f, 0.0f, 0.79729f, - 0.40467f, 0.75802f, 1.25205f, 1.05397f, 0.99662f, 0.0f, 0.05866f, - 1.25683f, 1.37623f, 1.3692f, 0.8155f, 0.79031f, 0.79231f, 0.0f, - 0.66813f, 1.55738f, 0.86795f, 1.74891f, 1.46206f, 0.44267f, 0.71223f, - 0.0f, 0.01532f, 0.9517f, 0.9068f, 0.04987f, 0.68475f, 0.60834f, - 0.5695f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, - 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, - 0.0f, 0.0f, 0.13772f, 0.0f, 0.0f, 0.54903f, 0.17714f, - 0.56106f, 0.37474f, 0.59682f, 0.80188f, 0.23357f, 0.0f, 0.3935f, - 0.10723f, 0.21271f, 0.2933f, 0.40208f, 0.98239f, 0.19075f, 0.06934f, - 0.69707f, 0.59654f, 0.72836f, 0.94042f, 0.29819f, 0.65969f, 0.15544f, - 0.21691f, 0.94429f, 0.74025f, 0.57482f, 0.85235f, 0.6364f, 0.64997f, - 0.43117f, 0.23959f, 0.86925f, 0.74496f, 1.18404f, 0.91728f, 0.66074f, - 0.14145f, 0.0f, 0.0f, 0.82383f, 0.54479f, 0.37769f, 0.37376f, - 0.18698f, 0.41482f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, - 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, - 0.0f, 0.0f, 0.0f, 0.19054f, 0.0f, 0.0f, 0.13366f, - 0.02072f, 0.17679f, 0.21344f, 0.22093f, 0.39159f, 0.0f, 0.0f, - 0.21636f, 0.1152f, 0.05384f, 0.17127f, 0.31197f, 0.26403f, 0.0f, - 0.0f, 0.2079f, 0.40094f, 0.25855f, 0.2949f, 0.21378f, 0.29504f, - 0.0f, 0.0f, 0.55198f, 0.28422f, 0.44235f, 0.39818f, 0.24589f, - 0.24885f, 0.0f, 0.0f, 0.39978f, 0.49578f, 0.31662f, 0.57204f, - 0.22104f, 0.09188f, 0.0f, 0.0f, 0.30446f, 0.11957f, 0.18297f, - 0.21063f, 0.11165f, 0.1131f, 0.0f, 0.0f, 0.0f, 0.0f, - 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.04903f, - 0.0f, 0.21626f, 0.35491f, 0.86898f, 0.9025f, 0.0f, 0.36255f, - 1.46154f, 1.38429f, 1.44938f, 1.41407f, 1.45809f, 1.77706f, 0.88361f, - 0.09394f, 0.92029f, 1.01541f, 1.09078f, 1.05394f, 1.25418f, 1.40895f, - 0.78881f, 0.62721f, 1.55362f, 1.70365f, 1.83765f, 1.7833f, 1.52613f, - 1.39727f, 0.44845f, 0.80839f, 1.73151f, 1.63702f, 1.60352f, 1.63081f, - 1.5767f, 1.99697f, 0.91883f, 0.62179f, 1.8053f, 1.63263f, 1.72401f, - 2.45383f, 1.25455f, 1.07616f, 0.38183f, 0.56256f, 1.8342f, 1.49708f, - 1.54651f, 0.90693f, 0.85377f, 0.9732f, 0.0f, 0.0f, 0.42826f, - 0.47554f, 0.23275f, 0.5115f, 0.14327f, 0.23193f, 0.0f - }; + input = {3.11462f, -0.81077f, -1.10521f, -0.72331f, 0.78823f, 1.36453f, + 0.37365f, -1.00043f, 0.00156f, -0.13156f, 0.10315f, -0.36979f, + -0.7116f, -0.1203f, 1.23831f, 0.19852f, -0.79851f, 0.27605f, + 0.09819f, 2.48209f, -1.20067f, 1.02096f, -0.38697f, -0.05689f, + -0.27344f, -0.06105f, 0.53209f, -0.89718f, -1.30166f, -1.37283f, + 1.69093f, -0.4622f, 0.20359f, -1.03283f, 1.13048f, -0.5703f, + -2.10094f, 0.38992f, 0.08734f, -0.85736f, -0.27462f, 0.44321f, + 0.95911f, 1.33195f, 0.77331f, 3.0567f, -2.4878f, -1.69617f}; + expected = {-31.72671f, -26.44757f, -10.69324f, 34.00166f, 17.61435f, + 27.69349f, 3.65285f, -23.72797f, -4.0109f, 15.94939f, + -3.76765f, 24.18737f, 9.72213f, 0.57065f, 29.3591f, + 3.71433f, 18.16798f, -6.33768f, -21.24062f, 1.64767f, + -9.12255f, -17.8523f, 7.95207f, 5.59372f, 20.16168f, + 2.86041f, -16.32113f, 29.07518f, 4.09777f, 8.65735f, + 1.00664f, -23.30697f, 7.38348f, 24.17699f, -15.56942f, + 9.82751f, -8.22256f, 0.88987f, 4.39324f, 11.62732f, + -6.18404f, 11.78396f, -37.57986f, 28.56649f, -16.00127f, + -30.68929f, 30.1704f, 14.04265f, -2.70738f, 4.89308f}; runTest(); -}; \ No newline at end of file +} \ No newline at end of file diff --git a/examples/inception_v3/tests/test_fixture.hpp b/examples/inception_v3/tests/test_fixture.hpp index 13b2ec2..742a3dd 100644 --- a/examples/inception_v3/tests/test_fixture.hpp +++ b/examples/inception_v3/tests/test_fixture.hpp @@ -6,6 +6,11 @@ #include "inception_v3.hpp" +bool __inline__ isCloseRelative(float a, float b, float rel_tol = 1e-3f, float abs_tol = 1e-3f) { + return std::abs(a - b) <= std::max(rel_tol * std::max(std::abs(a), std::abs(b)), abs_tol); +}; + + class InceptionBlockTest : public ::testing::Test { protected: CUDANet::Model *model; @@ -42,7 +47,7 @@ class InceptionBlockTest : public ::testing::Test { EXPECT_EQ(outputSize, expected.size()); for (int i = 0; i < outputSize; ++i) { - EXPECT_NEAR(expected[i], output[i], 1e-3f); + EXPECT_TRUE(isCloseRelative(expected[i], output[i])); } } }; diff --git a/examples/inception_v3/tests/test_inception_a.cpp b/examples/inception_v3/tests/test_inception_a.cpp index 81aee51..616cdb0 100644 --- a/examples/inception_v3/tests/test_inception_a.cpp +++ b/examples/inception_v3/tests/test_inception_a.cpp @@ -42,25 +42,27 @@ TEST_F(InceptionBlockTest, InceptionATest) { model = new InceptionAModel(inputShape, inputChannels, outputSize); model->loadWeights("../tests/resources/inception_a.bin"); - input = {-0.57666f, -1.62078f, -1.33839f, -0.84871f, 0.95678f, 0.5095f, - 0.46427f, 1.63513f, -0.86174f, 1.37982f, 2.11955f, -0.56532f, - -1.41889f, -0.24865f, 0.75658f, 1.41115f, -0.04036f, -0.13206f, - 0.54325f, -0.90184f, -0.30188f, -2.06574f, -0.12676f, 0.38189f, - 1.7959f, 0.24076f, 1.17587f, -0.21496f, 0.55819f, 0.21572f, - -1.66043f, 1.24566f, 0.837f, 0.13259f, -0.73019f, 0.87461f, - 1.38548f, -0.48258f, -0.11748f, 0.4244f, 1.14489f, 0.28394f, - -0.46594f, 1.18402f, -0.91973f, 0.63682f, -0.31897f, 0.80855f}; + input = {0.0961f, -2.01883f, -1.09473f, -0.86616f, -0.11743f, 0.30927f, + 0.89226f, -0.86077f, -0.54149f, 0.06902f, 0.41018f, 1.12389f, + -1.94536f, -1.01233f, -1.93574f, -0.37831f, 0.30134f, -0.71024f, + -0.85796f, -1.60188f, 0.16672f, 0.04074f, -0.17714f, -0.45344f, + -0.67299f, 0.39537f, 1.36158f, -0.04113f, 0.04399f, 0.08246f, + 0.28653f, 0.00399f, -0.76861f, -0.1379f, 1.23108f, 1.15452f, + -1.67101f, 0.892f, -0.22458f, 1.08748f, 0.63386f, 1.91399f, + 0.69495f, 1.39091f, -0.38628f, -1.3974f, -0.74191f, 0.40352f}; - expected = {4.80097f, 47.97113f, 2.84091f, -8.17906f, -27.73839f, - -20.00487f, 25.13571f, 55.96156f, -57.32637f, 46.79503f, - 30.84768f, -2.27363f, 11.58069f, -84.01064f, -86.74448f, - -34.90844f, -31.9425f, -26.9795f, 43.22921f, 18.58556f, - 19.94732f, 9.99053f, 77.01399f, -29.40551f, 22.79751f, - -25.38616f, 78.7154f, -3.62437f, -7.37189f, -37.58518f, - 28.78344f, 46.85378f, -84.57623f, 0.10005f, 64.6466f, - -10.21144f, 51.44754f, 15.37502f, 24.96819f, -34.59124f, - 2.73933f, -32.52842f, -1.32425f, 41.48183f, -12.74939f, - -102.07105f, 5.58513f, 9.73683f, -25.97733f, -24.79673f}; + expected = { + 5525.64062f, -761.25549f, -2780.82837f, 1123.72534f, -5405.26465f, + -840.91406f, 3590.53394f, -3732.77344f, 945.03845f, 1172.73401f, + -1085.39026f, 1690.71399f, 2042.38208f, -5948.82129f, 2648.69897f, + 6884.2876f, -1833.52173f, 3289.4668f, 110.44409f, -1192.91907f, + 6087.70117f, 8234.74316f, 4488.75488f, 7.75244f, -2987.04834f, + -5129.2124f, -3235.24585f, -336.58179f, 2506.66943f, 598.82483f, + 488.68921f, 913.30005f, -6063.51318f, 3496.71753f, -4504.59473f, + 1082.13867f, 3889.91968f, 1248.47168f, -742.7981f, -2244.45215f, + -1985.24561f, 14646.59766f, -310.81506f, 4763.40527f, -3007.1792f, + 382.23853f, -2357.31445f, 3503.68457f, -5159.74902f, -5777.59863f + }; runTest(); } \ No newline at end of file diff --git a/examples/inception_v3/tests/test_inception_b.cpp b/examples/inception_v3/tests/test_inception_b.cpp index 583bfc7..90ead54 100644 --- a/examples/inception_v3/tests/test_inception_b.cpp +++ b/examples/inception_v3/tests/test_inception_b.cpp @@ -51,16 +51,18 @@ TEST_F(InceptionBlockTest, InceptionBTest) { 0.98517f, 0.32883f, -0.28338f, 0.81102f, 0.70454f, 0.84246f, 0.93766f, -0.83322f, 0.58987f, 1.23888f, -0.6962f, 0.68079f}; - expected = {-19.94974f, 93.30141f, 85.76035f, 214.1964f, -4.30855f, - -92.65581f, -37.0993f, -96.92029f, -145.99411f, 177.49068f, - 185.37115f, 263.2403f, -158.78972f, -435.17844f, 208.36617f, - -481.91907f, 179.2296f, 42.68469f, 661.90039f, 192.62759f, - 146.78622f, 59.37774f, -107.44885f, 578.51874f, 745.61536f, - 528.30847f, -54.60599f, 237.63603f, 198.97778f, 9.95003f, - 61.68781f, -156.87708f, -166.12646f, 294.20853f, 382.63782f, - 53.15688f, -74.18913f, 7.70657f, 202.17197f, -121.06818f, - 32.45838f, -401.04041f, 94.91491f, 240.54332f, 171.52502f, - -121.58651f, 419.89447f, 161.91119f, -13.53201f, -74.76675f}; + expected = { + -3051.75049f, -4.80383f, -1978.21191f, -522.09924f, 1021.27625f, + 2102.80273f, 875.50775f, -466.79095f, -706.03009f, 2394.21826f, + 1953.83984f, -1130.63367f, 1569.5769f, -12.87457f, -502.60977f, + 593.30615f, 104.63843f, -1463.10815f, -1655.98389f, -1414.4104f, + -366.82794f, -2672.62769f, -1057.31287f, 832.19531f, -116.99335f, + 476.57092f, -1208.35327f, 357.08228f, 2724.59375f, 1238.1272f, + 1124.98877f, 566.73798f, -2852.34058f, -98.82605f, -4457.8584f, + 1228.86597f, 1112.53467f, 2053.17212f, 396.1055f, -2534.39136f, + 1349.99756f, -792.96722f, -1477.9967f, 450.82751f, -297.31879f, + 294.22925f, 3548.19165f, -63.16211f, -1651.43335f, 51.88046f + }; runTest(); } \ No newline at end of file diff --git a/examples/inception_v3/tests/test_inception_c.cpp b/examples/inception_v3/tests/test_inception_c.cpp index ef0a098..a0ad392 100644 --- a/examples/inception_v3/tests/test_inception_c.cpp +++ b/examples/inception_v3/tests/test_inception_c.cpp @@ -51,19 +51,19 @@ TEST_F(InceptionBlockTest, InceptionCTest) { -1.60163f, -0.27932f, -0.20508f, 1.31193f, -0.7601f, -0.0586f, -0.21923f, -0.85385f, -1.10512f, -0.22181f, 0.94507f, -0.09808f}; - expected = {-9231.45508f, -11854.50684f, -2690.15942f, -6366.60303f, - -6953.01855f, -2204.80371f, 1670.89551f, 18207.81641f, - -8896.50977f, 10661.94434f, 3338.31055f, -3853.95947f, - 1445.87354f, -9627.54297f, 4166.00635f, -22477.38477f, - 11400.2207f, 8139.3877f, 8114.41602f, -2006.37793f, - -9130.33398f, 10554.69824f, 5194.41016f, -7031.67969f, - -10880.09277f, -4093.95068f, 6500.65967f, -459.13672f, - -10640.70215f, 6096.37842f, 12178.46094f, 5894.95117f, - -3034.80225f, -5177.80518f, -6112.60449f, -7296.75879f, - -1134.77344f, -13472.27637f, 8982.56543f, -3773.67334f, - -4207.74609f, -4001.82129f, -6682.51953f, -12314.57617f, - -6180.21875f, -886.62231f, 5490.0752f, 4868.64893f, - -12725.73633f, -3121.33716f}; + expected = {386048.9375f, -721189.75f, 1039515.0625f, -92812.95312f, + 533437.75f, -617244.0f, -81946.21094f, -775994.25f, + -653376.0f, -690453.25f, 218790.28125f, 454025.3125f, + 947592.375f, 280879.25f, -61118.59375f, -88742.75781f, + -458026.0625f, 82204.71875f, -297425.9375f, 114420.0625f, + 397277.71875f, 593181.375f, 582754.125f, -614345.1875f, + -173317.15625f, -220982.48438f, -932588.5625f, 339467.5625f, + 917578.125f, -95884.16406f, 83229.875f, 434552.375f, + 231232.1875f, 142239.71875f, -264704.5f, 854149.0f, + 462348.21875f, 33728.0f, 24409.39062f, -509526.3125f, + -279235.5625f, 570330.0f, 103149.71875f, 26780.33984f, + -328880.71875f, 1027994.8125f, -585315.0f, -210921.71875f, + 492957.53125f, -122604.57031f}; runTest(); } diff --git a/examples/inception_v3/tests/test_inception_d.cpp b/examples/inception_v3/tests/test_inception_d.cpp index 7cb02fd..fc4e1bd 100644 --- a/examples/inception_v3/tests/test_inception_d.cpp +++ b/examples/inception_v3/tests/test_inception_d.cpp @@ -51,18 +51,19 @@ TEST_F(InceptionBlockTest, InceptionDTest) { -0.89471f, -0.0348f, -1.49654f, -1.18578f, -2.013f, -0.47656f, -0.16578f, 0.21603f, -0.23605f, -0.53382f, -0.25789f, 2.30887f}; - expected = { - -778.66046f, 2780.01416f, -908.03717f, 720.61572f, 975.1803f, - -2017.04016f, 2678.03955f, -2089.99609f, -1231.16272f, 4078.28247f, - -765.89209f, -2531.9021f, -1590.11182f, 6677.42822f, 174.45618f, - -1065.43262f, 4505.68066f, 3798.1748f, 1419.7229f, 2433.96948f, - 355.61597f, 1356.61279f, -2179.37061f, -973.08789f, 2414.1543f, - -2190.11792f, -157.86133f, 1810.07166f, 2140.48706f, 8073.00488f, - 2629.58789f, 4686.91992f, -3285.09985f, 5723.23584f, 1181.26648f, - -5476.90723f, 4895.85547f, -1787.32935f, 2138.9646f, 1336.84277f, - -3492.97656f, 3706.74121f, 703.98871f, -2263.92188f, 4441.91016f, - -3471.9314f, 3354.59106f, 5038.75928f, -3676.13037f, 563.34637f - }; + expected = {-21052.13086f, 97856.53125f, 114996.78125f, -26694.66602f, + -51989.99219f, -41073.51562f, 52375.89844f, -101566.1875f, + 110595.30469f, -34081.6875f, 41151.85938f, -116816.51562f, + 12594.64941f, -86867.95312f, -103277.80469f, -31095.63281f, + 30530.58984f, -47046.89844f, 94815.74219f, -24208.12891f, + -50130.52734f, 38272.71094f, 102970.35938f, 92221.41406f, + 20659.89258f, -60365.08984f, 10940.85938f, -48804.74219f, + 119315.45312f, -49296.32031f, -113509.04688f, -19691.87305f, + -62688.6875f, -94743.73438f, -77935.0f, -84231.10156f, + 58992.52344f, -23301.23828f, -34058.94531f, -27215.86328f, + -103682.59375f, 13735.66992f, 7671.27002f, -68139.50781f, + -59972.78125f, -7613.14844f, -34182.88281f, 29532.60352f, + -71745.53906f, -137596.1875f}; runTest(); } \ No newline at end of file diff --git a/examples/inception_v3/tests/test_inception_e.cpp b/examples/inception_v3/tests/test_inception_e.cpp index 8585147..c43c8c7 100644 --- a/examples/inception_v3/tests/test_inception_e.cpp +++ b/examples/inception_v3/tests/test_inception_e.cpp @@ -51,19 +51,19 @@ TEST_F(InceptionBlockTest, InceptionETest) { 2.1034f, 1.65832f, 1.63788f, -1.32596f, -1.43412f, -1.28353f, 0.70226f, 0.9459f, 0.8579f, 0.15361f, 0.34449f, -1.70587f}; - expected = {1614.15283f, -11319.01855f, 614.40479f, 5280.0293f, - 1914.45007f, -2937.50317f, -11177.16113f, 3215.01245f, - 6249.16992f, 5654.91357f, -11702.27148f, 13057.32422f, - 8665.35742f, 3911.11743f, 5239.45947f, -11552.88477f, - -8056.7666f, -16426.19922f, -1383.04346f, 6573.53125f, - -12226.16992f, -6641.0957f, -9614.80078f, -9313.30273f, - 7023.68848f, 2089.5752f, 1095.53369f, -1387.65698f, - -7928.21729f, -9489.18848f, 4159.78613f, -690.03442f, - -8356.81738f, 12364.08203f, 8226.95703f, 8822.66602f, - -5462.90381f, -1037.42773f, 12958.68555f, -666.58423f, - 2032.38574f, -9534.14062f, -947.41333f, 689.37158f, - 4585.76465f, -23245.36719f, 975.83398f, -1253.45703f, - -14745.35059f, -2588.05493f}; + expected = {-52475.21094f, -45850.59766f, 25258.94727f, -123668.88281f, + -124592.32812f, 120878.47656f, 69247.67188f, 3390.39258f, + -17620.58594f, 5239.70117f, -30841.2793f, -134645.84375f, + -71254.0f, -69958.625f, 27372.9668f, -10891.0293f, + 52875.20703f, 810.01172f, -57457.20312f, -26664.05469f, + -8147.90527f, -139440.09375f, -71311.84375f, -53446.54688f, + 25358.27148f, -42854.97656f, 57698.98438f, 63391.79688f, + 54427.98438f, 89160.73438f, 79430.96094f, -51700.30469f, + 29048.21094f, -28000.3418f, -29570.61133f, -16047.83691f, + -69285.42188f, -13865.00391f, 17681.38672f, -45284.46484f, + 42490.97656f, 30390.58203f, 21886.40039f, -89973.20312f, + 75571.00781f, 19183.16797f, -37130.51562f, 12787.17383f, + 59336.42578f, -88201.78125f}; runTest(); } \ No newline at end of file