Support ndcg- and map- (#4635)

This commit is contained in:
Philip Hyunsu Cho 2019-07-03 22:51:48 -07:00 committed by GitHub
parent 4e9fad74eb
commit 96bf91725b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 37 additions and 7 deletions

View File

@ -15,13 +15,22 @@ namespace xgboost {
Metric* Metric::Create(const std::string& name, LearnerTrainParam const* tparam) { Metric* Metric::Create(const std::string& name, LearnerTrainParam const* tparam) {
std::string buf = name; std::string buf = name;
std::string prefix = name; std::string prefix = name;
const char* param;
auto pos = buf.find('@'); auto pos = buf.find('@');
if (pos == std::string::npos) { if (pos == std::string::npos) {
auto *e = ::dmlc::Registry< ::xgboost::MetricReg>::Get()->Find(name); if (!buf.empty() && buf.back() == '-') {
// Metrics of form "metric-"
prefix = buf.substr(0, buf.length() - 1); // Chop off '-'
param = "-";
} else {
prefix = buf;
param = nullptr;
}
auto *e = ::dmlc::Registry< ::xgboost::MetricReg>::Get()->Find(prefix.c_str());
if (e == nullptr) { if (e == nullptr) {
LOG(FATAL) << "Unknown metric function " << name; LOG(FATAL) << "Unknown metric function " << name;
} }
auto p_metric = (e->body)(nullptr); auto p_metric = (e->body)(param);
p_metric->tparam_ = tparam; p_metric->tparam_ = tparam;
return p_metric; return p_metric;
} else { } else {

View File

@ -285,10 +285,13 @@ struct EvalRankList : public Metric {
minus_ = false; minus_ = false;
if (param != nullptr) { if (param != nullptr) {
std::ostringstream os; std::ostringstream os;
os << name << '@' << param; if (sscanf(param, "%u[-]?", &topn_) == 1) {
name_ = os.str(); os << name << '@' << param;
if (sscanf(param, "%u[-]?", &topn_) != 1) { name_ = os.str();
} else {
topn_ = std::numeric_limits<unsigned>::max(); topn_ = std::numeric_limits<unsigned>::max();
os << name << param;
name_ = os.str();
} }
if (param[strlen(param) - 1] == '-') { if (param[strlen(param) - 1] == '-') {
minus_ = true; minus_ = true;

View File

@ -113,7 +113,18 @@ TEST(Metric, NDCG) {
delete metric; delete metric;
metric = xgboost::Metric::Create("ndcg@-", &tparam); metric = xgboost::Metric::Create("ndcg@-", &tparam);
ASSERT_STREQ(metric->Name(), "ndcg@-"); ASSERT_STREQ(metric->Name(), "ndcg-");
EXPECT_NEAR(GetMetricEval(metric,
xgboost::HostDeviceVector<xgboost::bst_float>{},
{}), 0, 1e-10);
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10);
EXPECT_NEAR(GetMetricEval(metric,
{0.1f, 0.9f, 0.1f, 0.9f},
{ 0, 0, 1, 1}),
0.6509f, 0.001f);
delete metric;
metric = xgboost::Metric::Create("ndcg-", &tparam);
ASSERT_STREQ(metric->Name(), "ndcg-");
EXPECT_NEAR(GetMetricEval(metric, EXPECT_NEAR(GetMetricEval(metric,
xgboost::HostDeviceVector<xgboost::bst_float>{}, xgboost::HostDeviceVector<xgboost::bst_float>{},
{}), 0, 1e-10); {}), 0, 1e-10);
@ -150,7 +161,14 @@ TEST(Metric, MAP) {
delete metric; delete metric;
metric = xgboost::Metric::Create("map@-", &tparam); metric = xgboost::Metric::Create("map@-", &tparam);
ASSERT_STREQ(metric->Name(), "map@-"); ASSERT_STREQ(metric->Name(), "map-");
EXPECT_NEAR(GetMetricEval(metric,
xgboost::HostDeviceVector<xgboost::bst_float>{},
{}), 0, 1e-10);
delete metric;
metric = xgboost::Metric::Create("map-", &tparam);
ASSERT_STREQ(metric->Name(), "map-");
EXPECT_NEAR(GetMetricEval(metric, EXPECT_NEAR(GetMetricEval(metric,
xgboost::HostDeviceVector<xgboost::bst_float>{}, xgboost::HostDeviceVector<xgboost::bst_float>{},
{}), 0, 1e-10); {}), 0, 1e-10);