rank_metric: add AUC-PR (#3172)

* rank_metric: add AUC-PR

Implementation of the AUC-PR calculation for weighted data, proposed by Keilwagen, Grosse and Grau (https://doi.org/10.1371/journal.pone.0092209)

* rank_metric: fix lint warnings

* Implement tests for AUC-PR and fix implementation

* add aucpr to documentation for other languages
This commit is contained in:
Arjan van der Velde
2018-03-23 10:43:47 -04:00
committed by Yuan (Terry) Tang
parent 8fb3388af2
commit 04221a7469
7 changed files with 121 additions and 2 deletions

View File

@@ -350,6 +350,94 @@ struct EvalCox : public Metric {
}
};
/*!
 * \brief Area under the precision-recall curve (AUC-PR), for both
 *        classification and rank.
 *
 * Implementation of AUC-PR for weighted data, translated from the PRROC R
 * package; see https://doi.org/10.1371/journal.pone.0092209
 *
 * The metric is computed independently for every query group and the
 * per-group values are averaged.
 */
struct EvalAucPR : public Metric {
  /*!
   * \brief Evaluate the mean per-group AUC-PR.
   * \param preds predictions, one entry per label
   * \param info meta information; labels, instance weights (through
   *        GetWeight) and group_ptr are read. When group_ptr is empty all
   *        rows are treated as a single group.
   * \param distributed when true, the local AUC sum and the local group count
   *        are summed over all workers with rabit before dividing
   * \return sum of the per-group AUC-PR values divided by the number of groups
   *
   * Fails (through CHECK) when the label set is empty, when sizes do not
   * match, when a group has zero positive or zero negative weight, when the
   * accumulated weights of a group are negative, or when the AUC-PR of a
   * single group exceeds 1.0.
   */
  bst_float Eval(const std::vector<bst_float> &preds, const MetaInfo &info,
                 bool distributed) const override {
    CHECK_NE(info.labels.size(), 0U) << "label set cannot be empty";
    CHECK_EQ(preds.size(), info.labels.size())
        << "label size predict size not match";
    // fallback group structure: one group spanning every row
    std::vector<unsigned> tgptr(2, 0);
    tgptr[1] = static_cast<unsigned>(info.labels.size());
    const std::vector<unsigned> &gptr =
        info.group_ptr.size() == 0 ? tgptr : info.group_ptr;
    CHECK_EQ(gptr.back(), info.labels.size())
        << "EvalAucPR: group structure must match number of prediction";
    const bst_omp_uint ngroup = static_cast<bst_omp_uint>(gptr.size() - 1);
    // sum of the per-group AUC-PR values
    double auc = 0.0;
    int auc_error = 0, auc_gt_one = 0;
    // (prediction, row index) pairs of the current group; the buffer is
    // reused across groups (the group loop below is serial)
    std::vector<std::pair<bst_float, unsigned>> rec;
    for (bst_omp_uint k = 0; k < ngroup; ++k) {
      double total_pos = 0.0;
      double total_neg = 0.0;
      rec.clear();
      for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j) {
        total_pos += info.GetWeight(j) * info.labels[j];
        total_neg += info.GetWeight(j) * (1.0f - info.labels[j]);
        rec.push_back(std::make_pair(preds[j], j));
      }
      // we need pos > 0 && neg > 0; a degenerate group is reported after the
      // loop, so skip the sort and the integration, which would otherwise
      // divide by a zero total_pos
      if (0.0 == total_pos || 0.0 == total_neg) {
        auc_error = 1;
        continue;
      }
      // NOTE(review): CmpFirst is defined elsewhere; presumably it orders the
      // pairs by descending prediction - confirm.
      XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
      // AUC-PR of this group only; the upper-bound check must look at this
      // value and not at the sum over groups, which legitimately exceeds 1.0
      // as soon as there are several groups
      double group_auc = 0.0;
      double tp = 0.0, prevtp = 0.0, fp = 0.0, prevfp = 0.0;
      for (size_t j = 0; j < rec.size(); ++j) {
        const unsigned ridx = rec[j].second;
        tp += info.GetWeight(ridx) * info.labels[ridx];
        fp += info.GetWeight(ridx) * (1.0f - info.labels[ridx]);
        // a segment is closed only at the last row of a run of equal
        // predictions (the final row always closes one)
        if (j + 1 < rec.size() && rec[j].first == rec[j + 1].first) {
          continue;
        }
        // Inside the segment fp is interpolated linearly in tp with slope h,
        // so precision as a function of recall r is r / (a * r + b); the
        // increments below are the closed-form integral of that expression
        // over [prevtp / total_pos, tp / total_pos].
        double a = 1.0, b = 0.0;
        if (tp != prevtp) {
          const double h = (fp - prevfp) / (tp - prevtp);
          a = 1.0 + h;
          b = (prevfp - h * prevtp) / total_pos;
        }
        if (0.0 != b) {
          group_auc += (tp / total_pos - prevtp / total_pos -
                        b / a * (std::log(a * tp / total_pos + b) -
                                 std::log(a * prevtp / total_pos + b))) / a;
        } else {
          group_auc += (tp / total_pos - prevtp / total_pos) / a;
        }
        if (group_auc > 1.0) {
          auc_gt_one = 1;
        }
        prevtp = tp;
        prevfp = fp;
      }
      // sanity check: the accumulated weights must not be negative. After the
      // loop prevtp/prevfp equal tp/fp, so testing tp and fp is sufficient.
      CHECK(!(tp < 0 || fp < 0)) << "AUC-PR: error in calculation";
      auc += group_auc;
    }
    CHECK(!auc_error) << "AUC-PR: the dataset only contains pos or neg samples";
    CHECK(!auc_gt_one) << "AUC-PR: AUC > 1.0";
    if (distributed) {
      bst_float dat[2];
      dat[0] = static_cast<bst_float>(auc);
      dat[1] = static_cast<bst_float>(ngroup);
      // approximately estimate auc using mean
      rabit::Allreduce<rabit::op::Sum>(dat, 2);
      return dat[0] / dat[1];
    } else {
      return static_cast<bst_float>(auc) / ngroup;
    }
  }
  /*! \return the name under which this metric is reported */
  const char *Name() const override { return "aucpr"; }
};
// Registration entry for the "ams" metric: the factory lambda forwards the
// parameter string to the EvalAMS constructor and returns the new object.
XGBOOST_REGISTER_METRIC(AMS, "ams")
.describe("AMS metric for higgs.")
.set_body([](const char* param) { return new EvalAMS(param); });
@@ -358,6 +446,10 @@ XGBOOST_REGISTER_METRIC(Auc, "auc")
.describe("Area under curve for both classification and rank.")
.set_body([](const char* param) { return new EvalAuc(); });
// Registration entry for the "aucpr" metric: the factory lambda creates an
// EvalAucPR; the parameter string is accepted but not used.
XGBOOST_REGISTER_METRIC(AucPR, "aucpr")
.describe("Area under PR curve for both classification and rank.")
.set_body([](const char* param) { return new EvalAucPR(); });
// Registration entry for the "pre" metric (precision@k): the factory lambda
// forwards the parameter string to the EvalPrecision constructor.
XGBOOST_REGISTER_METRIC(Precision, "pre")
.describe("precision@k for rank.")
.set_body([](const char* param) { return new EvalPrecision(param); });
@@ -375,3 +467,4 @@ XGBOOST_REGISTER_METRIC(Cox, "cox-nloglik")
.set_body([](const char* param) { return new EvalCox(); });
} // namespace metric
} // namespace xgboost