rank_metric: add AUC-PR (#3172)

* rank_metric: add AUC-PR

Implementation of the AUC-PR calculation for weighted data, proposed by Keilwagen, Grosse and Grau (https://doi.org/10.1371/journal.pone.0092209)

* rank_metric: fix lint warnings

* Implement tests for AUC-PR and fix implementation

* add aucpr to documentation for other languages
Arjan van der Velde 2018-03-23 10:43:47 -04:00 committed by Yuan (Terry) Tang
parent 8fb3388af2
commit 04221a7469
7 changed files with 121 additions and 2 deletions
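The commit message above references the weighted-data AUC-PR interpolation of Keilwagen, Grosse and Grau (translated from the PRROC R package in the C++ change below). As a reading aid, here is a minimal standalone R sketch of that scheme; it mirrors the logic of the new EvalAucPR metric, and the function and argument names are illustrative only, not part of this PR:

# Minimal R sketch of weighted AUC-PR via the Keilwagen/Grosse/Grau interpolation.
# scores: predictions, labels: 0/1, weights: optional instance weights.
aucpr_weighted <- function(scores, labels, weights = rep(1, length(labels))) {
  ord <- order(scores, decreasing = TRUE)
  scores <- scores[ord]; labels <- labels[ord]; weights <- weights[ord]
  total_pos <- sum(weights * labels)
  tp <- 0; fp <- 0; prevtp <- 0; prevfp <- 0; auc <- 0
  n <- length(scores)
  for (j in seq_len(n)) {
    tp <- tp + weights[j] * labels[j]
    fp <- fp + weights[j] * (1 - labels[j])
    # close a segment only when the score changes (ties share one PR point)
    if (j == n || scores[j] != scores[j + 1]) {
      if (tp == prevtp) {
        a <- 1; b <- 0
      } else {
        h <- (fp - prevfp) / (tp - prevtp)   # FP interpolated linearly in TP
        a <- 1 + h
        b <- (prevfp - h * prevtp) / total_pos
      }
      r1 <- tp / total_pos; r0 <- prevtp / total_pos
      if (b != 0) {
        auc <- auc + (r1 - r0 - b / a * (log(a * r1 + b) - log(a * r0 + b))) / a
      } else {
        auc <- auc + (r1 - r0) / a
      }
      prevtp <- tp; prevfp <- fp
    }
  }
  auc
}
# e.g. aucpr_weighted(c(0.1, 0.9, 0.1, 0.9), c(0, 0, 1, 1)) returns 0.5,
# matching the second expectation in the new C++ unit test below.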


@@ -34,6 +34,7 @@
#' \item \code{rmse} Rooted mean square error
#' \item \code{logloss} negative log-likelihood function
#' \item \code{auc} Area under curve
#' \item \code{aucpr} Area under PR curve
#' \item \code{merror} Exact matching error, used to evaluate multi-class classification
#' }
#' @param obj customized objective function. Returns gradient and second order


@@ -127,6 +127,7 @@
#' Different threshold (e.g., 0.) could be specified as "error@0."
#' \item \code{merror} Multiclass classification error rate. It is calculated as \code{(# wrong cases) / (# all cases)}.
#' \item \code{auc} Area under the curve. \url{http://en.wikipedia.org/wiki/Receiver_operating_characteristic#'Area_under_curve} for ranking evaluation.
#' \item \code{aucpr} Area under the PR curve. \url{https://en.wikipedia.org/wiki/Precision_and_recall} for ranking evaluation.
#' \item \code{ndcg} Normalized Discounted Cumulative Gain (for ranking task). \url{http://en.wikipedia.org/wiki/NDCG}
#' }
#'


@@ -51,6 +51,7 @@ from each CV model. This parameter engages the \code{\link{cb.cv.predict}} callb
\item \code{rmse} Rooted mean square error
\item \code{logloss} negative log-likelihood function
\item \code{auc} Area under curve
\item \code{aucpr} Area under PR curve
\item \code{merror} Exact matching error, used to evaluate multi-class classification
}}


@@ -186,6 +186,7 @@ The folloiwing is the list of built-in metrics for which Xgboost provides optimi
Different threshold (e.g., 0.) could be specified as "error@0."
\item \code{merror} Multiclass classification error rate. It is calculated as \code{(# wrong cases) / (# all cases)}.
\item \code{auc} Area under the curve. \url{http://en.wikipedia.org/wiki/Receiver_operating_characteristic#'Area_under_curve} for ranking evaluation.
\item \code{aucpr} Area under the PR curve. \url{https://en.wikipedia.org/wiki/Precision_and_recall} for ranking evaluation.
\item \code{ndcg} Normalized Discounted Cumulative Gain (for ranking task). \url{http://en.wikipedia.org/wiki/NDCG}
}
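The R documentation above now lists aucpr next to the other built-in evaluation metrics. A small usage sketch, assuming an xgboost build that includes this change (the agaricus data ships with the R package):

library(xgboost)
data(agaricus.train, package = "xgboost")
# select the new metric through eval_metric, exactly as for "auc" or "error"
bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label,
               nrounds = 5, objective = "binary:logistic",
               eval_metric = "aucpr")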


@@ -45,7 +45,8 @@ trait LearningTaskParams extends Params {
  /**
   * evaluation metrics for validation data, a default metric will be assigned according to
   * objective(rmse for regression, and error for classification, mean average precision for
   * ranking). options: rmse, mae, logloss, error, merror, mlogloss, auc, aucpr, ndcg, map,
   * gamma-deviance
   */
  val evalMetric = new Param[String](this, "eval_metric", "evaluation metrics for validation" +
    " data, a default metric will be assigned according to objective (rmse for regression, and" +
@@ -97,5 +98,5 @@ private[spark] object LearningTaskParams {
    "reg:gamma")
  val supportedEvalMetrics = HashSet("rmse", "mae", "logloss", "error", "merror", "mlogloss",
    "auc", "aucpr", "ndcg", "map", "gamma-deviance")
}


@@ -350,6 +350,94 @@ struct EvalCox : public Metric {
  }
};
/*! \brief Area Under PR Curve, for both classification and rank */
struct EvalAucPR : public Metric {
  // implementation of AUC-PR for weighted data
  // translated from PRROC R Package
  // see https://doi.org/10.1371/journal.pone.0092209
  bst_float Eval(const std::vector<bst_float> &preds, const MetaInfo &info,
                 bool distributed) const override {
    CHECK_NE(info.labels.size(), 0U) << "label set cannot be empty";
    CHECK_EQ(preds.size(), info.labels.size())
        << "label size predict size not match";
    std::vector<unsigned> tgptr(2, 0);
    tgptr[1] = static_cast<unsigned>(info.labels.size());
    const std::vector<unsigned> &gptr =
        info.group_ptr.size() == 0 ? tgptr : info.group_ptr;
    CHECK_EQ(gptr.back(), info.labels.size())
        << "EvalAucPR: group structure must match number of prediction";
    const bst_omp_uint ngroup = static_cast<bst_omp_uint>(gptr.size() - 1);
    // sum statistics
    double auc = 0.0;
    int auc_error = 0, auc_gt_one = 0;
    // each thread takes a local rec
    std::vector<std::pair<bst_float, unsigned>> rec;
    for (bst_omp_uint k = 0; k < ngroup; ++k) {
      double total_pos = 0.0;
      double total_neg = 0.0;
      rec.clear();
      for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j) {
        total_pos += info.GetWeight(j) * info.labels[j];
        total_neg += info.GetWeight(j) * (1.0f - info.labels[j]);
        rec.push_back(std::make_pair(preds[j], j));
      }
      XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
      // we need pos > 0 && neg > 0
      if (0.0 == total_pos || 0.0 == total_neg) {
        auc_error = 1;
      }
      // calculate AUC
      double tp = 0.0, prevtp = 0.0, fp = 0.0, prevfp = 0.0, h = 0.0, a = 0.0, b = 0.0;
      for (size_t j = 0; j < rec.size(); ++j) {
        tp += info.GetWeight(rec[j].second) * info.labels[rec[j].second];
        fp += info.GetWeight(rec[j].second) * (1.0f - info.labels[rec[j].second]);
        if ((j < rec.size() - 1 && rec[j].first != rec[j + 1].first) || j == rec.size() - 1) {
          if (tp == prevtp) {
            h = 1.0;
            a = 1.0;
            b = 0.0;
          } else {
            h = (fp - prevfp) / (tp - prevtp);
            a = 1.0 + h;
            b = (prevfp - h * prevtp) / total_pos;
          }
          if (0.0 != b) {
            auc += (tp / total_pos - prevtp / total_pos -
                    b / a * (std::log(a * tp / total_pos + b) -
                             std::log(a * prevtp / total_pos + b))) / a;
          } else {
            auc += (tp / total_pos - prevtp / total_pos) / a;
          }
          if (auc > 1.0) {
            auc_gt_one = 1;
          }
          prevtp = tp;
          prevfp = fp;
        }
      }
      // sanity check
      if (tp < 0 || prevtp < 0 || fp < 0 || prevfp < 0) {
        CHECK(!auc_error) << "AUC-PR: error in calculation";
      }
    }
    CHECK(!auc_error) << "AUC-PR: the dataset only contains pos or neg samples";
    CHECK(!auc_gt_one) << "AUC-PR: AUC > 1.0";
    if (distributed) {
      bst_float dat[2];
      dat[0] = static_cast<bst_float>(auc);
      dat[1] = static_cast<bst_float>(ngroup);
      // approximately estimate auc using mean
      rabit::Allreduce<rabit::op::Sum>(dat, 2);
      return dat[0] / dat[1];
    } else {
      return static_cast<bst_float>(auc) / ngroup;
    }
  }
  const char *Name() const override { return "aucpr"; }
};
XGBOOST_REGISTER_METRIC(AMS, "ams")
.describe("AMS metric for higgs.")
.set_body([](const char* param) { return new EvalAMS(param); });
@@ -358,6 +446,10 @@ XGBOOST_REGISTER_METRIC(Auc, "auc")
.describe("Area under curve for both classification and rank.")
.set_body([](const char* param) { return new EvalAuc(); });
XGBOOST_REGISTER_METRIC(AucPR, "aucpr")
.describe("Area under PR curve for both classification and rank.")
.set_body([](const char* param) { return new EvalAucPR(); });
XGBOOST_REGISTER_METRIC(Precision, "pre")
.describe("precision@k for rank.")
.set_body([](const char* param) { return new EvalPrecision(param); });
@@ -375,3 +467,4 @@ XGBOOST_REGISTER_METRIC(Cox, "cox-nloglik")
.set_body([](const char* param) { return new EvalCox(); });
} // namespace metric
} // namespace xgboost
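For reference, the per-segment update in EvalAucPR::Eval follows from integrating precision over recall under the linear interpolation of Keilwagen et al. Writing $P$ for total_pos, $r_0 = \mathrm{prevtp}/P$, $r_1 = \mathrm{tp}/P$, and $h = (\mathrm{fp}-\mathrm{prevfp})/(\mathrm{tp}-\mathrm{prevtp})$, precision as a function of recall $r$ on the segment is $r/(ar+b)$ with $a = 1+h$ and $b = (\mathrm{prevfp} - h\,\mathrm{prevtp})/P$, so the segment contributes

$$\int_{r_0}^{r_1} \frac{r}{a r + b}\,dr \;=\; \frac{1}{a}\Big[(r_1 - r_0) - \frac{b}{a}\big(\ln(a r_1 + b) - \ln(a r_0 + b)\big)\Big],$$

which reduces to $(r_1 - r_0)/a$ when $b = 0$; these are exactly the two branches in the inner loop.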


@@ -31,6 +31,27 @@ TEST(Metric, AUC) {
  EXPECT_ANY_THROW(GetMetricEval(metric, {0, 0}, {0, 0}));
}
TEST(Metric, AUCPR) {
  xgboost::Metric *metric = xgboost::Metric::Create("aucpr");
  ASSERT_STREQ(metric->Name(), "aucpr");
  EXPECT_NEAR(GetMetricEval(metric, {0, 0, 1, 1}, {0, 0, 1, 1}), 1, 1e-10);
  EXPECT_NEAR(GetMetricEval(metric, {0.1f, 0.9f, 0.1f, 0.9f}, {0, 0, 1, 1}),
              0.5f, 0.001f);
  EXPECT_NEAR(
      GetMetricEval(metric,
                    {0.4f, 0.2f, 0.9f, 0.1f, 0.2f, 0.4f, 0.1f, 0.1f, 0.2f, 0.1f},
                    {0, 0, 0, 0, 0, 1, 0, 0, 1, 1}),
      0.2908445f, 0.001f);
  EXPECT_NEAR(GetMetricEval(
                  metric, {0.87f, 0.31f, 0.40f, 0.42f, 0.25f, 0.66f, 0.95f,
                           0.09f, 0.10f, 0.97f, 0.76f, 0.69f, 0.15f, 0.20f,
                           0.30f, 0.14f, 0.07f, 0.58f, 0.61f, 0.08f},
                  {0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1}),
              0.2769199f, 0.001f);
  EXPECT_ANY_THROW(GetMetricEval(metric, {0, 1}, {}));
  EXPECT_ANY_THROW(GetMetricEval(metric, {0, 0}, {0, 0}));
}
TEST(Metric, Precision) {
  // When the limit for precision is not given, it takes the limit at
  // std::numeric_limits<unsigned>::max(); hence all values are very small