Support ndcg- and map- (#4635)
This commit is contained in:
parent
4e9fad74eb
commit
96bf91725b
@ -15,13 +15,22 @@ namespace xgboost {
|
||||
Metric* Metric::Create(const std::string& name, LearnerTrainParam const* tparam) {
|
||||
std::string buf = name;
|
||||
std::string prefix = name;
|
||||
const char* param;
|
||||
auto pos = buf.find('@');
|
||||
if (pos == std::string::npos) {
|
||||
auto *e = ::dmlc::Registry< ::xgboost::MetricReg>::Get()->Find(name);
|
||||
if (!buf.empty() && buf.back() == '-') {
|
||||
// Metrics of form "metric-"
|
||||
prefix = buf.substr(0, buf.length() - 1); // Chop off '-'
|
||||
param = "-";
|
||||
} else {
|
||||
prefix = buf;
|
||||
param = nullptr;
|
||||
}
|
||||
auto *e = ::dmlc::Registry< ::xgboost::MetricReg>::Get()->Find(prefix.c_str());
|
||||
if (e == nullptr) {
|
||||
LOG(FATAL) << "Unknown metric function " << name;
|
||||
}
|
||||
auto p_metric = (e->body)(nullptr);
|
||||
auto p_metric = (e->body)(param);
|
||||
p_metric->tparam_ = tparam;
|
||||
return p_metric;
|
||||
} else {
|
||||
|
||||
@ -285,10 +285,13 @@ struct EvalRankList : public Metric {
|
||||
minus_ = false;
|
||||
if (param != nullptr) {
|
||||
std::ostringstream os;
|
||||
os << name << '@' << param;
|
||||
name_ = os.str();
|
||||
if (sscanf(param, "%u[-]?", &topn_) != 1) {
|
||||
if (sscanf(param, "%u[-]?", &topn_) == 1) {
|
||||
os << name << '@' << param;
|
||||
name_ = os.str();
|
||||
} else {
|
||||
topn_ = std::numeric_limits<unsigned>::max();
|
||||
os << name << param;
|
||||
name_ = os.str();
|
||||
}
|
||||
if (param[strlen(param) - 1] == '-') {
|
||||
minus_ = true;
|
||||
|
||||
@ -113,7 +113,18 @@ TEST(Metric, NDCG) {
|
||||
|
||||
delete metric;
|
||||
metric = xgboost::Metric::Create("ndcg@-", &tparam);
|
||||
ASSERT_STREQ(metric->Name(), "ndcg@-");
|
||||
ASSERT_STREQ(metric->Name(), "ndcg-");
|
||||
EXPECT_NEAR(GetMetricEval(metric,
|
||||
xgboost::HostDeviceVector<xgboost::bst_float>{},
|
||||
{}), 0, 1e-10);
|
||||
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10);
|
||||
EXPECT_NEAR(GetMetricEval(metric,
|
||||
{0.1f, 0.9f, 0.1f, 0.9f},
|
||||
{ 0, 0, 1, 1}),
|
||||
0.6509f, 0.001f);
|
||||
delete metric;
|
||||
metric = xgboost::Metric::Create("ndcg-", &tparam);
|
||||
ASSERT_STREQ(metric->Name(), "ndcg-");
|
||||
EXPECT_NEAR(GetMetricEval(metric,
|
||||
xgboost::HostDeviceVector<xgboost::bst_float>{},
|
||||
{}), 0, 1e-10);
|
||||
@ -150,7 +161,14 @@ TEST(Metric, MAP) {
|
||||
|
||||
delete metric;
|
||||
metric = xgboost::Metric::Create("map@-", &tparam);
|
||||
ASSERT_STREQ(metric->Name(), "map@-");
|
||||
ASSERT_STREQ(metric->Name(), "map-");
|
||||
EXPECT_NEAR(GetMetricEval(metric,
|
||||
xgboost::HostDeviceVector<xgboost::bst_float>{},
|
||||
{}), 0, 1e-10);
|
||||
|
||||
delete metric;
|
||||
metric = xgboost::Metric::Create("map-", &tparam);
|
||||
ASSERT_STREQ(metric->Name(), "map-");
|
||||
EXPECT_NEAR(GetMetricEval(metric,
|
||||
xgboost::HostDeviceVector<xgboost::bst_float>{},
|
||||
{}), 0, 1e-10);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user