Support ndcg- and map- (#4635)

This commit is contained in:
Philip Hyunsu Cho 2019-07-03 22:51:48 -07:00 committed by GitHub
parent 4e9fad74eb
commit 96bf91725b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 37 additions and 7 deletions

View File

@ -15,13 +15,22 @@ namespace xgboost {
Metric* Metric::Create(const std::string& name, LearnerTrainParam const* tparam) {
std::string buf = name;
std::string prefix = name;
const char* param;
auto pos = buf.find('@');
if (pos == std::string::npos) {
auto *e = ::dmlc::Registry< ::xgboost::MetricReg>::Get()->Find(name);
if (!buf.empty() && buf.back() == '-') {
// Metrics of form "metric-"
prefix = buf.substr(0, buf.length() - 1); // Chop off '-'
param = "-";
} else {
prefix = buf;
param = nullptr;
}
auto *e = ::dmlc::Registry< ::xgboost::MetricReg>::Get()->Find(prefix.c_str());
if (e == nullptr) {
LOG(FATAL) << "Unknown metric function " << name;
}
auto p_metric = (e->body)(nullptr);
auto p_metric = (e->body)(param);
p_metric->tparam_ = tparam;
return p_metric;
} else {

View File

@ -285,10 +285,13 @@ struct EvalRankList : public Metric {
minus_ = false;
if (param != nullptr) {
std::ostringstream os;
os << name << '@' << param;
name_ = os.str();
if (sscanf(param, "%u[-]?", &topn_) != 1) {
if (sscanf(param, "%u[-]?", &topn_) == 1) {
os << name << '@' << param;
name_ = os.str();
} else {
topn_ = std::numeric_limits<unsigned>::max();
os << name << param;
name_ = os.str();
}
if (param[strlen(param) - 1] == '-') {
minus_ = true;

View File

@ -113,7 +113,18 @@ TEST(Metric, NDCG) {
delete metric;
metric = xgboost::Metric::Create("ndcg@-", &tparam);
ASSERT_STREQ(metric->Name(), "ndcg@-");
ASSERT_STREQ(metric->Name(), "ndcg-");
EXPECT_NEAR(GetMetricEval(metric,
xgboost::HostDeviceVector<xgboost::bst_float>{},
{}), 0, 1e-10);
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10);
EXPECT_NEAR(GetMetricEval(metric,
{0.1f, 0.9f, 0.1f, 0.9f},
{ 0, 0, 1, 1}),
0.6509f, 0.001f);
delete metric;
metric = xgboost::Metric::Create("ndcg-", &tparam);
ASSERT_STREQ(metric->Name(), "ndcg-");
EXPECT_NEAR(GetMetricEval(metric,
xgboost::HostDeviceVector<xgboost::bst_float>{},
{}), 0, 1e-10);
@ -150,7 +161,14 @@ TEST(Metric, MAP) {
delete metric;
metric = xgboost::Metric::Create("map@-", &tparam);
ASSERT_STREQ(metric->Name(), "map@-");
ASSERT_STREQ(metric->Name(), "map-");
EXPECT_NEAR(GetMetricEval(metric,
xgboost::HostDeviceVector<xgboost::bst_float>{},
{}), 0, 1e-10);
delete metric;
metric = xgboost::Metric::Create("map-", &tparam);
ASSERT_STREQ(metric->Name(), "map-");
EXPECT_NEAR(GetMetricEval(metric,
xgboost::HostDeviceVector<xgboost::bst_float>{},
{}), 0, 1e-10);