From 04221a7469dbd0b43b2719eede51387913580db3 Mon Sep 17 00:00:00 2001 From: Arjan van der Velde Date: Fri, 23 Mar 2018 10:43:47 -0400 Subject: [PATCH] rank_metric: add AUC-PR (#3172) * rank_metric: add AUC-PR Implementation of the AUC-PR calculation for weighted data, proposed by Keilwagen, Grosse and Grau (https://doi.org/10.1371/journal.pone.0092209) * rank_metric: fix lint warnings * Implement tests for AUC-PR and fix implementation * add aucpr to documentation for other languages --- R-package/R/xgb.cv.R | 1 + R-package/R/xgb.train.R | 1 + R-package/man/xgb.cv.Rd | 1 + R-package/man/xgb.train.Rd | 1 + .../spark/params/LearningTaskParams.scala | 5 +- src/metric/rank_metric.cc | 93 +++++++++++++++++++ tests/cpp/metric/test_rank_metric.cc | 21 +++++ 7 files changed, 121 insertions(+), 2 deletions(-) diff --git a/R-package/R/xgb.cv.R b/R-package/R/xgb.cv.R index 54c9f2d0b..652d52995 100644 --- a/R-package/R/xgb.cv.R +++ b/R-package/R/xgb.cv.R @@ -34,6 +34,7 @@ #' \item \code{rmse} Rooted mean square error #' \item \code{logloss} negative log-likelihood function #' \item \code{auc} Area under curve +#' \item \code{aucpr} Area under PR curve #' \item \code{merror} Exact matching error, used to evaluate multi-class classification #' } #' @param obj customized objective function. Returns gradient and second order diff --git a/R-package/R/xgb.train.R b/R-package/R/xgb.train.R index 26e6bc737..fa8285473 100644 --- a/R-package/R/xgb.train.R +++ b/R-package/R/xgb.train.R @@ -127,6 +127,7 @@ #' Different threshold (e.g., 0.) could be specified as "error@0." #' \item \code{merror} Multiclass classification error rate. It is calculated as \code{(# wrong cases) / (# all cases)}. #' \item \code{auc} Area under the curve. \url{http://en.wikipedia.org/wiki/Receiver_operating_characteristic#'Area_under_curve} for ranking evaluation. +#' \item \code{aucpr} Area under the PR curve. \url{https://en.wikipedia.org/wiki/Precision_and_recall} for ranking evaluation. 
#' \item \code{ndcg} Normalized Discounted Cumulative Gain (for ranking task). \url{http://en.wikipedia.org/wiki/NDCG} #' } #' diff --git a/R-package/man/xgb.cv.Rd b/R-package/man/xgb.cv.Rd index 31d41324a..bdac5a22a 100644 --- a/R-package/man/xgb.cv.Rd +++ b/R-package/man/xgb.cv.Rd @@ -51,6 +51,7 @@ from each CV model. This parameter engages the \code{\link{cb.cv.predict}} callb \item \code{rmse} Rooted mean square error \item \code{logloss} negative log-likelihood function \item \code{auc} Area under curve + \item \code{aucpr} Area under PR curve \item \code{merror} Exact matching error, used to evaluate multi-class classification }} diff --git a/R-package/man/xgb.train.Rd b/R-package/man/xgb.train.Rd index b93298911..868ad2034 100644 --- a/R-package/man/xgb.train.Rd +++ b/R-package/man/xgb.train.Rd @@ -186,6 +186,7 @@ The folloiwing is the list of built-in metrics for which Xgboost provides optimi Different threshold (e.g., 0.) could be specified as "error@0." \item \code{merror} Multiclass classification error rate. It is calculated as \code{(# wrong cases) / (# all cases)}. \item \code{auc} Area under the curve. \url{http://en.wikipedia.org/wiki/Receiver_operating_characteristic#'Area_under_curve} for ranking evaluation. + \item \code{aucpr} Area under the PR curve. \url{https://en.wikipedia.org/wiki/Precision_and_recall} for ranking evaluation. \item \code{ndcg} Normalized Discounted Cumulative Gain (for ranking task). 
\url{http://en.wikipedia.org/wiki/NDCG} } diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala index b86c0de0a..0c5055bf5 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala @@ -45,7 +45,8 @@ trait LearningTaskParams extends Params { /** * evaluation metrics for validation data, a default metric will be assigned according to * objective(rmse for regression, and error for classification, mean average precision for - * ranking). options: rmse, mae, logloss, error, merror, mlogloss, auc, ndcg, map, gamma-deviance + * ranking). options: rmse, mae, logloss, error, merror, mlogloss, auc, aucpr, ndcg, map, + * gamma-deviance */ val evalMetric = new Param[String](this, "eval_metric", "evaluation metrics for validation" + " data, a default metric will be assigned according to objective (rmse for regression, and" + @@ -97,5 +98,5 @@ private[spark] object LearningTaskParams { "reg:gamma") val supportedEvalMetrics = HashSet("rmse", "mae", "logloss", "error", "merror", "mlogloss", - "auc", "ndcg", "map", "gamma-deviance") + "auc", "aucpr", "ndcg", "map", "gamma-deviance") } diff --git a/src/metric/rank_metric.cc b/src/metric/rank_metric.cc index 032c08bd2..216169ca1 100644 --- a/src/metric/rank_metric.cc +++ b/src/metric/rank_metric.cc @@ -350,6 +350,94 @@ struct EvalCox : public Metric { } }; +/*! 
\brief Area Under PR Curve, for both classification and rank */ +struct EvalAucPR : public Metric { + // implementation of AUC-PR for weighted data + // translated from PRROC R Package + // see https://doi.org/10.1371/journal.pone.0092209 + + bst_float Eval(const std::vector &preds, const MetaInfo &info, + bool distributed) const override { + CHECK_NE(info.labels.size(), 0U) << "label set cannot be empty"; + CHECK_EQ(preds.size(), info.labels.size()) + << "label size predict size not match"; + std::vector tgptr(2, 0); + tgptr[1] = static_cast(info.labels.size()); + const std::vector &gptr = + info.group_ptr.size() == 0 ? tgptr : info.group_ptr; + CHECK_EQ(gptr.back(), info.labels.size()) + << "EvalAucPR: group structure must match number of prediction"; + const bst_omp_uint ngroup = static_cast(gptr.size() - 1); + // sum statistics + double auc = 0.0; + int auc_error = 0, auc_gt_one = 0; + // (prediction, instance index) buffer, reused across groups + std::vector> rec; + for (bst_omp_uint k = 0; k < ngroup; ++k) { + double total_pos = 0.0; + double total_neg = 0.0; + rec.clear(); + for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j) { + total_pos += info.GetWeight(j) * info.labels[j]; + total_neg += info.GetWeight(j) * (1.0f - info.labels[j]); + rec.push_back(std::make_pair(preds[j], j)); + } + XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst); + // we need pos > 0 && neg > 0 + if (0.0 == total_pos || 0.0 == total_neg) { + auc_error = 1; + } + // calculate AUC + double tp = 0.0, prevtp = 0.0, fp = 0.0, prevfp = 0.0, h = 0.0, a = 0.0, b = 0.0; + for (size_t j = 0; j < rec.size(); ++j) { + tp += info.GetWeight(rec[j].second) * info.labels[rec[j].second]; + fp += info.GetWeight(rec[j].second) * (1.0f - info.labels[rec[j].second]); + if ((j < rec.size() - 1 && rec[j].first != rec[j + 1].first) || j == rec.size() - 1) { + if (tp == prevtp) { + h = 1.0; + a = 1.0; + b = 0.0; + } else { + h = (fp - prevfp) / (tp - prevtp); + a = 1.0 + h; + b = (prevfp - h * prevtp) / total_pos; + } + if
(0.0 != b) { + auc += (tp / total_pos - prevtp / total_pos - + b / a * (std::log(a * tp / total_pos + b) - + std::log(a * prevtp / total_pos + b))) / a; + } else { + auc += (tp / total_pos - prevtp / total_pos) / a; + } + if (auc > 1.0) { + auc_gt_one = 1; + } + prevtp = tp; + prevfp = fp; + } + } + // sanity check + if (tp < 0 || prevtp < 0 || fp < 0 || prevfp < 0) { + CHECK(!auc_error) << "AUC-PR: error in calculation"; + } + } + CHECK(!auc_error) << "AUC-PR: the dataset only contains pos or neg samples"; + CHECK(!auc_gt_one) << "AUC-PR: AUC > 1.0"; + if (distributed) { + bst_float dat[2]; + dat[0] = static_cast(auc); + dat[1] = static_cast(ngroup); + // approximately estimate auc using mean + rabit::Allreduce(dat, 2); + return dat[0] / dat[1]; + } else { + return static_cast(auc) / ngroup; + } + } + const char *Name() const override { return "aucpr"; } +}; + + XGBOOST_REGISTER_METRIC(AMS, "ams") .describe("AMS metric for higgs.") .set_body([](const char* param) { return new EvalAMS(param); }); @@ -358,6 +446,10 @@ XGBOOST_REGISTER_METRIC(Auc, "auc") .describe("Area under curve for both classification and rank.") .set_body([](const char* param) { return new EvalAuc(); }); +XGBOOST_REGISTER_METRIC(AucPR, "aucpr") +.describe("Area under PR curve for both classification and rank.") +.set_body([](const char* param) { return new EvalAucPR(); }); + XGBOOST_REGISTER_METRIC(Precision, "pre") .describe("precision@k for rank.") .set_body([](const char* param) { return new EvalPrecision(param); }); @@ -375,3 +467,4 @@ XGBOOST_REGISTER_METRIC(Cox, "cox-nloglik") .set_body([](const char* param) { return new EvalCox(); }); } // namespace metric } // namespace xgboost + diff --git a/tests/cpp/metric/test_rank_metric.cc b/tests/cpp/metric/test_rank_metric.cc index cfe99e82c..f0f7a0090 100644 --- a/tests/cpp/metric/test_rank_metric.cc +++ b/tests/cpp/metric/test_rank_metric.cc @@ -31,6 +31,27 @@ TEST(Metric, AUC) { EXPECT_ANY_THROW(GetMetricEval(metric, {0, 0}, {0, 0})); } 
+TEST(Metric, AUCPR) { + xgboost::Metric *metric = xgboost::Metric::Create("aucpr"); + ASSERT_STREQ(metric->Name(), "aucpr"); + EXPECT_NEAR(GetMetricEval(metric, {0, 0, 1, 1}, {0, 0, 1, 1}), 1, 1e-10); + EXPECT_NEAR(GetMetricEval(metric, {0.1f, 0.9f, 0.1f, 0.9f}, {0, 0, 1, 1}), + 0.5f, 0.001f); + EXPECT_NEAR( + GetMetricEval(metric, + {0.4f, 0.2f, 0.9f, 0.1f, 0.2f, 0.4f, 0.1f, 0.1f, 0.2f, 0.1f}, + {0, 0, 0, 0, 0, 1, 0, 0, 1, 1}), + 0.2908445f, 0.001f); + EXPECT_NEAR(GetMetricEval( + metric, {0.87f, 0.31f, 0.40f, 0.42f, 0.25f, 0.66f, 0.95f, + 0.09f, 0.10f, 0.97f, 0.76f, 0.69f, 0.15f, 0.20f, + 0.30f, 0.14f, 0.07f, 0.58f, 0.61f, 0.08f}, + {0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1}), + 0.2769199f, 0.001f); + EXPECT_ANY_THROW(GetMetricEval(metric, {0, 1}, {})); + EXPECT_ANY_THROW(GetMetricEval(metric, {0, 0}, {0, 0})); +} + TEST(Metric, Precision) { // When the limit for precision is not given, it takes the limit at // std::numeric_limits::max(); hence all values are very small