From 03abd47f4945a8fdb9ffd0af76288493cd59afe1 Mon Sep 17 00:00:00 2001 From: AbdealiJK Date: Sun, 4 Dec 2016 13:15:09 +0530 Subject: [PATCH] tests/cpp: Add tests for Metric RMSE --- tests/cpp/helpers.cc | 11 +++++++++++ tests/cpp/helpers.h | 7 +++++++ tests/cpp/metric/test_elementwise_metric.cc | 14 ++++++++++++++ 3 files changed, 32 insertions(+) create mode 100644 tests/cpp/metric/test_elementwise_metric.cc diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index 31349b3ab..a6dee1133 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -49,3 +49,14 @@ void CheckObjFunction(xgboost::ObjFunction * obj, << " weight=" << weights[i]; } } + +xgboost::bst_float GetMetricEval(xgboost::Metric * metric, + std::vector preds, + std::vector labels, + std::vector weights) { + xgboost::MetaInfo info; + info.num_row = labels.size(); + info.labels = labels; + info.weights = weights; + return metric->Eval(preds, info, false); +} diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index 41945e3aa..94bef0771 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -13,6 +13,7 @@ #include #include +#include std::string TempFileName(); @@ -29,4 +30,10 @@ void CheckObjFunction(xgboost::ObjFunction * obj, std::vector out_grad, std::vector out_hess); +xgboost::bst_float GetMetricEval( + xgboost::Metric * metric, + std::vector preds, + std::vector labels, + std::vector weights = std::vector ()); + #endif diff --git a/tests/cpp/metric/test_elementwise_metric.cc b/tests/cpp/metric/test_elementwise_metric.cc new file mode 100644 index 000000000..8aba13e6a --- /dev/null +++ b/tests/cpp/metric/test_elementwise_metric.cc @@ -0,0 +1,14 @@ +// Copyright by Contributors +#include + +#include "../helpers.h" + +TEST(Metric, RMSE) { + xgboost::Metric * metric = xgboost::Metric::Create("rmse"); + ASSERT_STREQ(metric->Name(), "rmse"); + EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0, 1e-10); + EXPECT_NEAR(GetMetricEval(metric, + {0.1, 0.9, 0.1, 0.9}, + { 0, 0, 1, 1}), + 0.6403, 0.001); +}