rank_metric: add AUC-PR (#3172)
* rank_metric: add AUC-PR Implementation of the AUC-PR calculation for weighted data, proposed by Keilwagen, Grosse and Grau (https://doi.org/10.1371/journal.pone.0092209) * rank_metric: fix lint warnings * Implement tests for AUC-PR and fix implementation * add aucpr to documentation for other languages
This commit is contained in:
parent
8fb3388af2
commit
04221a7469
@ -34,6 +34,7 @@
|
||||
#' \item \code{rmse} Rooted mean square error
|
||||
#' \item \code{logloss} negative log-likelihood function
|
||||
#' \item \code{auc} Area under curve
|
||||
#' \item \code{aucpr} Area under PR curve
|
||||
#' \item \code{merror} Exact matching error, used to evaluate multi-class classification
|
||||
#' }
|
||||
#' @param obj customized objective function. Returns gradient and second order
|
||||
|
||||
@ -127,6 +127,7 @@
|
||||
#' Different threshold (e.g., 0.) could be specified as "error@0."
|
||||
#' \item \code{merror} Multiclass classification error rate. It is calculated as \code{(# wrong cases) / (# all cases)}.
|
||||
#' \item \code{auc} Area under the curve. \url{http://en.wikipedia.org/wiki/Receiver_operating_characteristic#'Area_under_curve} for ranking evaluation.
|
||||
#' \item \code{aucpr} Area under the PR curve. \url{https://en.wikipedia.org/wiki/Precision_and_recall} for ranking evaluation.
|
||||
#' \item \code{ndcg} Normalized Discounted Cumulative Gain (for ranking task). \url{http://en.wikipedia.org/wiki/NDCG}
|
||||
#' }
|
||||
#'
|
||||
|
||||
@ -51,6 +51,7 @@ from each CV model. This parameter engages the \code{\link{cb.cv.predict}} callb
|
||||
\item \code{rmse} Rooted mean square error
|
||||
\item \code{logloss} negative log-likelihood function
|
||||
\item \code{auc} Area under curve
|
||||
\item \code{aucpr} Area under PR curve
|
||||
\item \code{merror} Exact matching error, used to evaluate multi-class classification
|
||||
}}
|
||||
|
||||
|
||||
@ -186,6 +186,7 @@ The folloiwing is the list of built-in metrics for which Xgboost provides optimi
|
||||
Different threshold (e.g., 0.) could be specified as "error@0."
|
||||
\item \code{merror} Multiclass classification error rate. It is calculated as \code{(# wrong cases) / (# all cases)}.
|
||||
\item \code{auc} Area under the curve. \url{http://en.wikipedia.org/wiki/Receiver_operating_characteristic#'Area_under_curve} for ranking evaluation.
|
||||
\item \code{aucpr} Area under the PR curve. \url{https://en.wikipedia.org/wiki/Precision_and_recall} for ranking evaluation.
|
||||
\item \code{ndcg} Normalized Discounted Cumulative Gain (for ranking task). \url{http://en.wikipedia.org/wiki/NDCG}
|
||||
}
|
||||
|
||||
|
||||
@ -45,7 +45,8 @@ trait LearningTaskParams extends Params {
|
||||
/**
|
||||
* evaluation metrics for validation data, a default metric will be assigned according to
|
||||
* objective(rmse for regression, and error for classification, mean average precision for
|
||||
* ranking). options: rmse, mae, logloss, error, merror, mlogloss, auc, ndcg, map, gamma-deviance
|
||||
* ranking). options: rmse, mae, logloss, error, merror, mlogloss, auc, aucpr, ndcg, map,
|
||||
* gamma-deviance
|
||||
*/
|
||||
val evalMetric = new Param[String](this, "eval_metric", "evaluation metrics for validation" +
|
||||
" data, a default metric will be assigned according to objective (rmse for regression, and" +
|
||||
@ -97,5 +98,5 @@ private[spark] object LearningTaskParams {
|
||||
"reg:gamma")
|
||||
|
||||
val supportedEvalMetrics = HashSet("rmse", "mae", "logloss", "error", "merror", "mlogloss",
|
||||
"auc", "ndcg", "map", "gamma-deviance")
|
||||
"auc", "aucpr", "ndcg", "map", "gamma-deviance")
|
||||
}
|
||||
|
||||
@ -350,6 +350,94 @@ struct EvalCox : public Metric {
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief Area Under PR Curve, for both classification and rank */
|
||||
struct EvalAucPR : public Metric {
|
||||
// implementation of AUC-PR for weighted data
|
||||
// translated from PRROC R Package
|
||||
// see https://doi.org/10.1371/journal.pone.0092209
|
||||
|
||||
bst_float Eval(const std::vector<bst_float> &preds, const MetaInfo &info,
|
||||
bool distributed) const override {
|
||||
CHECK_NE(info.labels.size(), 0U) << "label set cannot be empty";
|
||||
CHECK_EQ(preds.size(), info.labels.size())
|
||||
<< "label size predict size not match";
|
||||
std::vector<unsigned> tgptr(2, 0);
|
||||
tgptr[1] = static_cast<unsigned>(info.labels.size());
|
||||
const std::vector<unsigned> &gptr =
|
||||
info.group_ptr.size() == 0 ? tgptr : info.group_ptr;
|
||||
CHECK_EQ(gptr.back(), info.labels.size())
|
||||
<< "EvalAucPR: group structure must match number of prediction";
|
||||
const bst_omp_uint ngroup = static_cast<bst_omp_uint>(gptr.size() - 1);
|
||||
// sum statistics
|
||||
double auc = 0.0;
|
||||
int auc_error = 0, auc_gt_one = 0;
|
||||
// each thread takes a local rec
|
||||
std::vector<std::pair<bst_float, unsigned>> rec;
|
||||
for (bst_omp_uint k = 0; k < ngroup; ++k) {
|
||||
double total_pos = 0.0;
|
||||
double total_neg = 0.0;
|
||||
rec.clear();
|
||||
for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j) {
|
||||
total_pos += info.GetWeight(j) * info.labels[j];
|
||||
total_neg += info.GetWeight(j) * (1.0f - info.labels[j]);
|
||||
rec.push_back(std::make_pair(preds[j], j));
|
||||
}
|
||||
XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
|
||||
// we need pos > 0 && neg > 0
|
||||
if (0.0 == total_pos || 0.0 == total_neg) {
|
||||
auc_error = 1;
|
||||
}
|
||||
// calculate AUC
|
||||
double tp = 0.0, prevtp = 0.0, fp = 0.0, prevfp = 0.0, h = 0.0, a = 0.0, b = 0.0;
|
||||
for (size_t j = 0; j < rec.size(); ++j) {
|
||||
tp += info.GetWeight(rec[j].second) * info.labels[rec[j].second];
|
||||
fp += info.GetWeight(rec[j].second) * (1.0f - info.labels[rec[j].second]);
|
||||
if ((j < rec.size() - 1 && rec[j].first != rec[j + 1].first) || j == rec.size() - 1) {
|
||||
if (tp == prevtp) {
|
||||
h = 1.0;
|
||||
a = 1.0;
|
||||
b = 0.0;
|
||||
} else {
|
||||
h = (fp - prevfp) / (tp - prevtp);
|
||||
a = 1.0 + h;
|
||||
b = (prevfp - h * prevtp) / total_pos;
|
||||
}
|
||||
if (0.0 != b) {
|
||||
auc += (tp / total_pos - prevtp / total_pos -
|
||||
b / a * (std::log(a * tp / total_pos + b) -
|
||||
std::log(a * prevtp / total_pos + b))) / a;
|
||||
} else {
|
||||
auc += (tp / total_pos - prevtp / total_pos) / a;
|
||||
}
|
||||
if (auc > 1.0) {
|
||||
auc_gt_one = 1;
|
||||
}
|
||||
prevtp = tp;
|
||||
prevfp = fp;
|
||||
}
|
||||
}
|
||||
// sanity check
|
||||
if (tp < 0 || prevtp < 0 || fp < 0 || prevfp < 0) {
|
||||
CHECK(!auc_error) << "AUC-PR: error in calculation";
|
||||
}
|
||||
}
|
||||
CHECK(!auc_error) << "AUC-PR: the dataset only contains pos or neg samples";
|
||||
CHECK(!auc_gt_one) << "AUC-PR: AUC > 1.0";
|
||||
if (distributed) {
|
||||
bst_float dat[2];
|
||||
dat[0] = static_cast<bst_float>(auc);
|
||||
dat[1] = static_cast<bst_float>(ngroup);
|
||||
// approximately estimate auc using mean
|
||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||
return dat[0] / dat[1];
|
||||
} else {
|
||||
return static_cast<bst_float>(auc) / ngroup;
|
||||
}
|
||||
}
|
||||
const char *Name() const override { return "aucpr"; }
|
||||
};
|
||||
|
||||
|
||||
XGBOOST_REGISTER_METRIC(AMS, "ams")
|
||||
.describe("AMS metric for higgs.")
|
||||
.set_body([](const char* param) { return new EvalAMS(param); });
|
||||
@ -358,6 +446,10 @@ XGBOOST_REGISTER_METRIC(Auc, "auc")
|
||||
.describe("Area under curve for both classification and rank.")
|
||||
.set_body([](const char* param) { return new EvalAuc(); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(AucPR, "aucpr")
|
||||
.describe("Area under PR curve for both classification and rank.")
|
||||
.set_body([](const char* param) { return new EvalAucPR(); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(Precision, "pre")
|
||||
.describe("precision@k for rank.")
|
||||
.set_body([](const char* param) { return new EvalPrecision(param); });
|
||||
@ -375,3 +467,4 @@ XGBOOST_REGISTER_METRIC(Cox, "cox-nloglik")
|
||||
.set_body([](const char* param) { return new EvalCox(); });
|
||||
} // namespace metric
|
||||
} // namespace xgboost
|
||||
|
||||
|
||||
@ -31,6 +31,27 @@ TEST(Metric, AUC) {
|
||||
EXPECT_ANY_THROW(GetMetricEval(metric, {0, 0}, {0, 0}));
|
||||
}
|
||||
|
||||
TEST(Metric, AUCPR) {
|
||||
xgboost::Metric *metric = xgboost::Metric::Create("aucpr");
|
||||
ASSERT_STREQ(metric->Name(), "aucpr");
|
||||
EXPECT_NEAR(GetMetricEval(metric, {0, 0, 1, 1}, {0, 0, 1, 1}), 1, 1e-10);
|
||||
EXPECT_NEAR(GetMetricEval(metric, {0.1f, 0.9f, 0.1f, 0.9f}, {0, 0, 1, 1}),
|
||||
0.5f, 0.001f);
|
||||
EXPECT_NEAR(
|
||||
GetMetricEval(metric,
|
||||
{0.4f, 0.2f, 0.9f, 0.1f, 0.2f, 0.4f, 0.1f, 0.1f, 0.2f, 0.1},
|
||||
{0, 0, 0, 0, 0, 1, 0, 0, 1, 1}),
|
||||
0.2908445f, 0.001f);
|
||||
EXPECT_NEAR(GetMetricEval(
|
||||
metric, {0.87f, 0.31f, 0.40f, 0.42f, 0.25f, 0.66f, 0.95f,
|
||||
0.09f, 0.10f, 0.97f, 0.76f, 0.69f, 0.15f, 0.20f,
|
||||
0.30f, 0.14f, 0.07f, 0.58f, 0.61f, 0.08f},
|
||||
{0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1}),
|
||||
0.2769199f, 0.001f);
|
||||
EXPECT_ANY_THROW(GetMetricEval(metric, {0, 1}, {}));
|
||||
EXPECT_ANY_THROW(GetMetricEval(metric, {0, 0}, {0, 0}));
|
||||
}
|
||||
|
||||
TEST(Metric, Precision) {
|
||||
// When the limit for precision is not given, it takes the limit at
|
||||
// std::numeric_limits<unsigned>::max(); hence all values are very small
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user