From 47ba2de7d4b37fe30c0f0d2c35c7c800051ea3eb Mon Sep 17 00:00:00 2001 From: AbdealiJK Date: Sun, 4 Dec 2016 16:31:12 +0530 Subject: [PATCH] tests/cpp: Add tests for multiclass_metric.cc --- tests/cpp/objective/test_multiclass_metric.cc | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 tests/cpp/objective/test_multiclass_metric.cc diff --git a/tests/cpp/objective/test_multiclass_metric.cc b/tests/cpp/objective/test_multiclass_metric.cc new file mode 100644 index 000000000..9d7c023d6 --- /dev/null +++ b/tests/cpp/objective/test_multiclass_metric.cc @@ -0,0 +1,28 @@ +// Copyright by Contributors +#include + +#include "../helpers.h" + +TEST(Metric, MultiClassError) { + xgboost::Metric * metric = xgboost::Metric::Create("merror"); + ASSERT_STREQ(metric->Name(), "merror"); + EXPECT_ANY_THROW(GetMetricEval(metric, {0}, {0, 0})); + EXPECT_NEAR(GetMetricEval( + metric, {1, 0, 0, 0, 1, 0, 0, 0, 1}, {0, 1, 2}), 0, 1e-10); + EXPECT_NEAR(GetMetricEval(metric, + {0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1}, + {0, 1, 2}), + 0.666, 0.001); +} + +TEST(Metric, MultiClassLogLoss) { + xgboost::Metric * metric = xgboost::Metric::Create("mlogloss"); + ASSERT_STREQ(metric->Name(), "mlogloss"); + EXPECT_ANY_THROW(GetMetricEval(metric, {0}, {0, 0})); + EXPECT_NEAR(GetMetricEval( + metric, {1, 0, 0, 0, 1, 0, 0, 0, 1}, {0, 1, 2}), 0, 1e-10); + EXPECT_NEAR(GetMetricEval(metric, + {0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1}, + {0, 1, 2}), + 2.302, 0.001); +}