rank_metric: add AUC-PR (#3172)
* rank_metric: add AUC-PR

  Implementation of the AUC-PR calculation for weighted data, as proposed by
  Keilwagen, Grosse and Grau (https://doi.org/10.1371/journal.pone.0092209)

* rank_metric: fix lint warnings

* Implement tests for AUC-PR and fix implementation

* add aucpr to documentation for other languages
parent 8fb3388af2
commit 04221a7469
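The metric being added is the area under the precision-recall curve, computed on weighted data following the cited paper. As a reading aid for the C++ hunk further down (the notation here is mine, not part of the commit): given scores s_j, labels y_j in {0, 1}, and instance weights w_j, the weighted counts at a score threshold t are

$$TP(t) = \sum_{j:\, s_j \ge t} w_j\, y_j, \qquad FP(t) = \sum_{j:\, s_j \ge t} w_j\, (1 - y_j), \qquad P = \sum_j w_j\, y_j,$$

and the metric is the integral of precision over recall,

$$\mathrm{AUCPR} = \int_0^1 \mathrm{prec}\;\mathrm{d}\,\mathrm{rec}, \qquad \mathrm{rec}(t) = \frac{TP(t)}{P}, \qquad \mathrm{prec}(t) = \frac{TP(t)}{TP(t) + FP(t)},$$

evaluated with the continuous interpolation of Keilwagen et al. rather than a step-function approximation.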
@@ -34,6 +34,7 @@
 #' \item \code{rmse} Rooted mean square error
 #' \item \code{logloss} negative log-likelihood function
 #' \item \code{auc} Area under curve
+#' \item \code{aucpr} Area under PR curve
 #' \item \code{merror} Exact matching error, used to evaluate multi-class classification
 #' }
 #' @param obj customized objective function. Returns gradient and second order
@@ -127,6 +127,7 @@
 #' Different threshold (e.g., 0.) could be specified as "error@0."
 #' \item \code{merror} Multiclass classification error rate. It is calculated as \code{(# wrong cases) / (# all cases)}.
 #' \item \code{auc} Area under the curve. \url{http://en.wikipedia.org/wiki/Receiver_operating_characteristic#'Area_under_curve} for ranking evaluation.
+#' \item \code{aucpr} Area under the PR curve. \url{https://en.wikipedia.org/wiki/Precision_and_recall} for ranking evaluation.
 #' \item \code{ndcg} Normalized Discounted Cumulative Gain (for ranking task). \url{http://en.wikipedia.org/wiki/NDCG}
 #' }
 #'
@@ -51,6 +51,7 @@ from each CV model. This parameter engages the \code{\link{cb.cv.predict}} callb
 \item \code{rmse} Rooted mean square error
 \item \code{logloss} negative log-likelihood function
 \item \code{auc} Area under curve
+\item \code{aucpr} Area under PR curve
 \item \code{merror} Exact matching error, used to evaluate multi-class classification
 }}
 
@@ -186,6 +186,7 @@ The folloiwing is the list of built-in metrics for which Xgboost provides optimi
 Different threshold (e.g., 0.) could be specified as "error@0."
 \item \code{merror} Multiclass classification error rate. It is calculated as \code{(# wrong cases) / (# all cases)}.
 \item \code{auc} Area under the curve. \url{http://en.wikipedia.org/wiki/Receiver_operating_characteristic#'Area_under_curve} for ranking evaluation.
+\item \code{aucpr} Area under the PR curve. \url{https://en.wikipedia.org/wiki/Precision_and_recall} for ranking evaluation.
 \item \code{ndcg} Normalized Discounted Cumulative Gain (for ranking task). \url{http://en.wikipedia.org/wiki/NDCG}
 }
 
@@ -45,7 +45,8 @@ trait LearningTaskParams extends Params {
   /**
    * evaluation metrics for validation data, a default metric will be assigned according to
    * objective(rmse for regression, and error for classification, mean average precision for
-   * ranking). options: rmse, mae, logloss, error, merror, mlogloss, auc, ndcg, map, gamma-deviance
+   * ranking). options: rmse, mae, logloss, error, merror, mlogloss, auc, aucpr, ndcg, map,
+   * gamma-deviance
    */
   val evalMetric = new Param[String](this, "eval_metric", "evaluation metrics for validation" +
     " data, a default metric will be assigned according to objective (rmse for regression, and" +
@@ -97,5 +98,5 @@ private[spark] object LearningTaskParams {
     "reg:gamma")
 
   val supportedEvalMetrics = HashSet("rmse", "mae", "logloss", "error", "merror", "mlogloss",
-    "auc", "ndcg", "map", "gamma-deviance")
+    "auc", "aucpr", "ndcg", "map", "gamma-deviance")
 }
@@ -350,6 +350,94 @@ struct EvalCox : public Metric {
   }
 };
+
+/*! \brief Area Under PR Curve, for both classification and rank */
+struct EvalAucPR : public Metric {
+  // implementation of AUC-PR for weighted data
+  // translated from PRROC R Package
+  // see https://doi.org/10.1371/journal.pone.0092209
+
+  bst_float Eval(const std::vector<bst_float> &preds, const MetaInfo &info,
+                 bool distributed) const override {
+    CHECK_NE(info.labels.size(), 0U) << "label set cannot be empty";
+    CHECK_EQ(preds.size(), info.labels.size())
+        << "label size predict size not match";
+    std::vector<unsigned> tgptr(2, 0);
+    tgptr[1] = static_cast<unsigned>(info.labels.size());
+    const std::vector<unsigned> &gptr =
+        info.group_ptr.size() == 0 ? tgptr : info.group_ptr;
+    CHECK_EQ(gptr.back(), info.labels.size())
+        << "EvalAucPR: group structure must match number of prediction";
+    const bst_omp_uint ngroup = static_cast<bst_omp_uint>(gptr.size() - 1);
+    // sum statistics
+    double auc = 0.0;
+    int auc_error = 0, auc_gt_one = 0;
+    // each thread takes a local rec
+    std::vector<std::pair<bst_float, unsigned>> rec;
+    for (bst_omp_uint k = 0; k < ngroup; ++k) {
+      double total_pos = 0.0;
+      double total_neg = 0.0;
+      rec.clear();
+      for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j) {
+        total_pos += info.GetWeight(j) * info.labels[j];
+        total_neg += info.GetWeight(j) * (1.0f - info.labels[j]);
+        rec.push_back(std::make_pair(preds[j], j));
+      }
+      XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
+      // we need pos > 0 && neg > 0
+      if (0.0 == total_pos || 0.0 == total_neg) {
+        auc_error = 1;
+      }
+      // calculate AUC
+      double tp = 0.0, prevtp = 0.0, fp = 0.0, prevfp = 0.0, h = 0.0, a = 0.0, b = 0.0;
+      for (size_t j = 0; j < rec.size(); ++j) {
+        tp += info.GetWeight(rec[j].second) * info.labels[rec[j].second];
+        fp += info.GetWeight(rec[j].second) * (1.0f - info.labels[rec[j].second]);
+        if ((j < rec.size() - 1 && rec[j].first != rec[j + 1].first) || j == rec.size() - 1) {
+          if (tp == prevtp) {
+            h = 1.0;
+            a = 1.0;
+            b = 0.0;
+          } else {
+            h = (fp - prevfp) / (tp - prevtp);
+            a = 1.0 + h;
+            b = (prevfp - h * prevtp) / total_pos;
+          }
+          if (0.0 != b) {
+            auc += (tp / total_pos - prevtp / total_pos -
+                    b / a * (std::log(a * tp / total_pos + b) -
+                             std::log(a * prevtp / total_pos + b))) / a;
+          } else {
+            auc += (tp / total_pos - prevtp / total_pos) / a;
+          }
+          if (auc > 1.0) {
+            auc_gt_one = 1;
+          }
+          prevtp = tp;
+          prevfp = fp;
+        }
+      }
+      // sanity check
+      if (tp < 0 || prevtp < 0 || fp < 0 || prevfp < 0) {
+        CHECK(!auc_error) << "AUC-PR: error in calculation";
+      }
+    }
+    CHECK(!auc_error) << "AUC-PR: the dataset only contains pos or neg samples";
+    CHECK(!auc_gt_one) << "AUC-PR: AUC > 1.0";
+    if (distributed) {
+      bst_float dat[2];
+      dat[0] = static_cast<bst_float>(auc);
+      dat[1] = static_cast<bst_float>(ngroup);
+      // approximately estimate auc using mean
+      rabit::Allreduce<rabit::op::Sum>(dat, 2);
+      return dat[0] / dat[1];
+    } else {
+      return static_cast<bst_float>(auc) / ngroup;
+    }
+  }
+  const char *Name() const override { return "aucpr"; }
+};
+
 
 XGBOOST_REGISTER_METRIC(AMS, "ams")
 .describe("AMS metric for higgs.")
 .set_body([](const char* param) { return new EvalAMS(param); });
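A note on the auc += update above: between consecutive distinct prediction values the implementation interpolates FP linearly in TP, which is the continuous PR interpolation of Keilwagen et al. Writing x = tp / total_pos and reusing the code's own h = (fp - prevfp) / (tp - prevtp), a = 1 + h, b = (prevfp - h * prevtp) / total_pos, precision along the segment is x / (a x + b), so each segment contributes

$$\int_{x_{\mathrm{prev}}}^{x} \frac{u}{a\,u + b}\,\mathrm{d}u = \frac{1}{a}\left[\,x - x_{\mathrm{prev}} - \frac{b}{a}\left(\ln(a\,x + b) - \ln(a\,x_{\mathrm{prev}} + b)\right)\right],$$

which is exactly the logarithmic branch, and collapses to (x - x_prev) / a when b = 0. In the tp == prevtp case the code sets a = 1, b = 0, so the contribution is zero, as expected for a purely vertical move in FP with recall unchanged.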
@@ -358,6 +446,10 @@ XGBOOST_REGISTER_METRIC(Auc, "auc")
 .describe("Area under curve for both classification and rank.")
 .set_body([](const char* param) { return new EvalAuc(); });
 
+XGBOOST_REGISTER_METRIC(AucPR, "aucpr")
+.describe("Area under PR curve for both classification and rank.")
+.set_body([](const char* param) { return new EvalAucPR(); });
+
 XGBOOST_REGISTER_METRIC(Precision, "pre")
 .describe("precision@k for rank.")
 .set_body([](const char* param) { return new EvalPrecision(param); });
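Once registered, the new metric is reachable through the factory by its string name, which is exactly how the unit tests below exercise it. A minimal standalone sketch of that flow follows; the header path and the plain-vector MetaInfo fields reflect my reading of the XGBoost internals of this period and are assumptions, not part of the commit:

#include <xgboost/metric.h>  // assumed header exposing xgboost::Metric and Metric::Create

#include <iostream>
#include <memory>
#include <vector>

int main() {
  // Look up the newly registered metric through the factory by name.
  std::unique_ptr<xgboost::Metric> metric(xgboost::Metric::Create("aucpr"));

  // Labels (and implicit unit weights) live in MetaInfo; predictions are a
  // plain vector, matching the Eval signature shown in the diff above.
  xgboost::MetaInfo info;
  info.labels = {0.0f, 0.0f, 1.0f, 1.0f};
  std::vector<xgboost::bst_float> preds = {0.1f, 0.9f, 0.1f, 0.9f};

  // distributed = false: evaluate locally, without the rabit::Allreduce path.
  std::cout << metric->Eval(preds, info, false) << std::endl;  // ~0.5 per the tests
  return 0;
}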
@@ -375,3 +467,4 @@ XGBOOST_REGISTER_METRIC(Cox, "cox-nloglik")
 .set_body([](const char* param) { return new EvalCox(); });
 }  // namespace metric
 }  // namespace xgboost
+
@@ -31,6 +31,27 @@ TEST(Metric, AUC) {
   EXPECT_ANY_THROW(GetMetricEval(metric, {0, 0}, {0, 0}));
 }
 
+TEST(Metric, AUCPR) {
+  xgboost::Metric *metric = xgboost::Metric::Create("aucpr");
+  ASSERT_STREQ(metric->Name(), "aucpr");
+  EXPECT_NEAR(GetMetricEval(metric, {0, 0, 1, 1}, {0, 0, 1, 1}), 1, 1e-10);
+  EXPECT_NEAR(GetMetricEval(metric, {0.1f, 0.9f, 0.1f, 0.9f}, {0, 0, 1, 1}),
+              0.5f, 0.001f);
+  EXPECT_NEAR(
+      GetMetricEval(metric,
+                    {0.4f, 0.2f, 0.9f, 0.1f, 0.2f, 0.4f, 0.1f, 0.1f, 0.2f, 0.1f},
+                    {0, 0, 0, 0, 0, 1, 0, 0, 1, 1}),
+      0.2908445f, 0.001f);
+  EXPECT_NEAR(GetMetricEval(
+                  metric, {0.87f, 0.31f, 0.40f, 0.42f, 0.25f, 0.66f, 0.95f,
+                           0.09f, 0.10f, 0.97f, 0.76f, 0.69f, 0.15f, 0.20f,
+                           0.30f, 0.14f, 0.07f, 0.58f, 0.61f, 0.08f},
+                  {0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1}),
+              0.2769199f, 0.001f);
+  EXPECT_ANY_THROW(GetMetricEval(metric, {0, 1}, {}));
+  EXPECT_ANY_THROW(GetMetricEval(metric, {0, 0}, {0, 0}));
+}
+
 TEST(Metric, Precision) {
   // When the limit for precision is not given, it takes the limit at
   // std::numeric_limits<unsigned>::max(); hence all values are very small
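As a hand check of the second assertion above (my own walkthrough, not part of the commit): with unit weights, total_pos = total_neg = 2 and the descending sort visits the two distinct thresholds 0.9 and 0.1. At 0.9, tp = fp = 1, giving h = 1, a = 2, b = 0, so the first segment contributes (1/2 - 0) / 2 = 0.25; at 0.1, tp = fp = 2, again h = 1, a = 2, b = 0, contributing (1 - 1/2) / 2 = 0.25. The total is the expected 0.5.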