rank_metric: add AUC-PR (#3172)
* rank_metric: add AUC-PR Implementation of the AUC-PR calculation for weighted data, proposed by Keilwagen, Grosse and Grau (https://doi.org/10.1371/journal.pone.0092209) * rank_metric: fix lint warnings * Implement tests for AUC-PR and fix implementation * add aucpr to documentation for other languages
This commit is contained in:
committed by
Yuan (Terry) Tang
parent
8fb3388af2
commit
04221a7469
@@ -350,6 +350,94 @@ struct EvalCox : public Metric {
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief Area Under PR Curve, for both classification and rank */
|
||||
struct EvalAucPR : public Metric {
|
||||
// implementation of AUC-PR for weighted data
|
||||
// translated from PRROC R Package
|
||||
// see https://doi.org/10.1371/journal.pone.0092209
|
||||
|
||||
bst_float Eval(const std::vector<bst_float> &preds, const MetaInfo &info,
|
||||
bool distributed) const override {
|
||||
CHECK_NE(info.labels.size(), 0U) << "label set cannot be empty";
|
||||
CHECK_EQ(preds.size(), info.labels.size())
|
||||
<< "label size predict size not match";
|
||||
std::vector<unsigned> tgptr(2, 0);
|
||||
tgptr[1] = static_cast<unsigned>(info.labels.size());
|
||||
const std::vector<unsigned> &gptr =
|
||||
info.group_ptr.size() == 0 ? tgptr : info.group_ptr;
|
||||
CHECK_EQ(gptr.back(), info.labels.size())
|
||||
<< "EvalAucPR: group structure must match number of prediction";
|
||||
const bst_omp_uint ngroup = static_cast<bst_omp_uint>(gptr.size() - 1);
|
||||
// sum statistics
|
||||
double auc = 0.0;
|
||||
int auc_error = 0, auc_gt_one = 0;
|
||||
// each thread takes a local rec
|
||||
std::vector<std::pair<bst_float, unsigned>> rec;
|
||||
for (bst_omp_uint k = 0; k < ngroup; ++k) {
|
||||
double total_pos = 0.0;
|
||||
double total_neg = 0.0;
|
||||
rec.clear();
|
||||
for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j) {
|
||||
total_pos += info.GetWeight(j) * info.labels[j];
|
||||
total_neg += info.GetWeight(j) * (1.0f - info.labels[j]);
|
||||
rec.push_back(std::make_pair(preds[j], j));
|
||||
}
|
||||
XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
|
||||
// we need pos > 0 && neg > 0
|
||||
if (0.0 == total_pos || 0.0 == total_neg) {
|
||||
auc_error = 1;
|
||||
}
|
||||
// calculate AUC
|
||||
double tp = 0.0, prevtp = 0.0, fp = 0.0, prevfp = 0.0, h = 0.0, a = 0.0, b = 0.0;
|
||||
for (size_t j = 0; j < rec.size(); ++j) {
|
||||
tp += info.GetWeight(rec[j].second) * info.labels[rec[j].second];
|
||||
fp += info.GetWeight(rec[j].second) * (1.0f - info.labels[rec[j].second]);
|
||||
if ((j < rec.size() - 1 && rec[j].first != rec[j + 1].first) || j == rec.size() - 1) {
|
||||
if (tp == prevtp) {
|
||||
h = 1.0;
|
||||
a = 1.0;
|
||||
b = 0.0;
|
||||
} else {
|
||||
h = (fp - prevfp) / (tp - prevtp);
|
||||
a = 1.0 + h;
|
||||
b = (prevfp - h * prevtp) / total_pos;
|
||||
}
|
||||
if (0.0 != b) {
|
||||
auc += (tp / total_pos - prevtp / total_pos -
|
||||
b / a * (std::log(a * tp / total_pos + b) -
|
||||
std::log(a * prevtp / total_pos + b))) / a;
|
||||
} else {
|
||||
auc += (tp / total_pos - prevtp / total_pos) / a;
|
||||
}
|
||||
if (auc > 1.0) {
|
||||
auc_gt_one = 1;
|
||||
}
|
||||
prevtp = tp;
|
||||
prevfp = fp;
|
||||
}
|
||||
}
|
||||
// sanity check
|
||||
if (tp < 0 || prevtp < 0 || fp < 0 || prevfp < 0) {
|
||||
CHECK(!auc_error) << "AUC-PR: error in calculation";
|
||||
}
|
||||
}
|
||||
CHECK(!auc_error) << "AUC-PR: the dataset only contains pos or neg samples";
|
||||
CHECK(!auc_gt_one) << "AUC-PR: AUC > 1.0";
|
||||
if (distributed) {
|
||||
bst_float dat[2];
|
||||
dat[0] = static_cast<bst_float>(auc);
|
||||
dat[1] = static_cast<bst_float>(ngroup);
|
||||
// approximately estimate auc using mean
|
||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||
return dat[0] / dat[1];
|
||||
} else {
|
||||
return static_cast<bst_float>(auc) / ngroup;
|
||||
}
|
||||
}
|
||||
const char *Name() const override { return "aucpr"; }
|
||||
};
|
||||
|
||||
|
||||
XGBOOST_REGISTER_METRIC(AMS, "ams")
|
||||
.describe("AMS metric for higgs.")
|
||||
.set_body([](const char* param) { return new EvalAMS(param); });
|
||||
@@ -358,6 +446,10 @@ XGBOOST_REGISTER_METRIC(Auc, "auc")
|
||||
.describe("Area under curve for both classification and rank.")
|
||||
.set_body([](const char* param) { return new EvalAuc(); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(AucPR, "aucpr")
|
||||
.describe("Area under PR curve for both classification and rank.")
|
||||
.set_body([](const char* param) { return new EvalAucPR(); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(Precision, "pre")
|
||||
.describe("precision@k for rank.")
|
||||
.set_body([](const char* param) { return new EvalPrecision(param); });
|
||||
@@ -375,3 +467,4 @@ XGBOOST_REGISTER_METRIC(Cox, "cox-nloglik")
|
||||
.set_body([](const char* param) { return new EvalCox(); });
|
||||
} // namespace metric
|
||||
} // namespace xgboost
|
||||
|
||||
|
||||
Reference in New Issue
Block a user