diff --git a/src/metric/metric.cc b/src/metric/metric.cc index db7753d68..0d73ee690 100644 --- a/src/metric/metric.cc +++ b/src/metric/metric.cc @@ -15,13 +15,22 @@ namespace xgboost { Metric* Metric::Create(const std::string& name, LearnerTrainParam const* tparam) { std::string buf = name; std::string prefix = name; + const char* param; auto pos = buf.find('@'); if (pos == std::string::npos) { - auto *e = ::dmlc::Registry< ::xgboost::MetricReg>::Get()->Find(name); + if (!buf.empty() && buf.back() == '-') { + // Metrics of form "metric-" + prefix = buf.substr(0, buf.length() - 1); // Chop off '-' + param = "-"; + } else { + prefix = buf; + param = nullptr; + } + auto *e = ::dmlc::Registry< ::xgboost::MetricReg>::Get()->Find(prefix.c_str()); if (e == nullptr) { LOG(FATAL) << "Unknown metric function " << name; } - auto p_metric = (e->body)(nullptr); + auto p_metric = (e->body)(param); p_metric->tparam_ = tparam; return p_metric; } else { diff --git a/src/metric/rank_metric.cc b/src/metric/rank_metric.cc index 6e4832109..bb1b053b7 100644 --- a/src/metric/rank_metric.cc +++ b/src/metric/rank_metric.cc @@ -285,10 +285,13 @@ struct EvalRankList : public Metric { minus_ = false; if (param != nullptr) { std::ostringstream os; - os << name << '@' << param; - name_ = os.str(); - if (sscanf(param, "%u[-]?", &topn_) != 1) { + if (sscanf(param, "%u[-]?", &topn_) == 1) { + os << name << '@' << param; + name_ = os.str(); + } else { topn_ = std::numeric_limits::max(); + os << name << param; + name_ = os.str(); } if (param[strlen(param) - 1] == '-') { minus_ = true; diff --git a/tests/cpp/metric/test_rank_metric.cc b/tests/cpp/metric/test_rank_metric.cc index 3aa8f0fa8..ef1d1377f 100644 --- a/tests/cpp/metric/test_rank_metric.cc +++ b/tests/cpp/metric/test_rank_metric.cc @@ -113,7 +113,18 @@ TEST(Metric, NDCG) { delete metric; metric = xgboost::Metric::Create("ndcg@-", &tparam); - ASSERT_STREQ(metric->Name(), "ndcg@-"); + ASSERT_STREQ(metric->Name(), "ndcg-"); + EXPECT_NEAR(GetMetricEval(metric, + xgboost::HostDeviceVector{}, + {}), 0, 1e-10); + EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10); + EXPECT_NEAR(GetMetricEval(metric, + {0.1f, 0.9f, 0.1f, 0.9f}, + { 0, 0, 1, 1}), + 0.6509f, 0.001f); + delete metric; + metric = xgboost::Metric::Create("ndcg-", &tparam); + ASSERT_STREQ(metric->Name(), "ndcg-"); EXPECT_NEAR(GetMetricEval(metric, xgboost::HostDeviceVector{}, {}), 0, 1e-10); @@ -150,7 +161,14 @@ TEST(Metric, MAP) { delete metric; metric = xgboost::Metric::Create("map@-", &tparam); - ASSERT_STREQ(metric->Name(), "map@-"); + ASSERT_STREQ(metric->Name(), "map-"); + EXPECT_NEAR(GetMetricEval(metric, + xgboost::HostDeviceVector{}, + {}), 0, 1e-10); + + delete metric; + metric = xgboost::Metric::Create("map-", &tparam); + ASSERT_STREQ(metric->Name(), "map-"); EXPECT_NEAR(GetMetricEval(metric, xgboost::HostDeviceVector{}, {}), 0, 1e-10);