From cd1526d3b1155432fca82e8b4895bc0b827b4d29 Mon Sep 17 00:00:00 2001 From: Xu Xiao Date: Mon, 8 Jul 2019 16:01:52 +0800 Subject: [PATCH] fix auc error in distributed mode caused by unbalanced dataset (#4645) --- src/metric/rank_metric.cc | 44 ++++++++++++++++++++++----------------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/src/metric/rank_metric.cc b/src/metric/rank_metric.cc index bb1b053b7..f29975e95 100644 --- a/src/metric/rank_metric.cc +++ b/src/metric/rank_metric.cc @@ -197,18 +197,21 @@ struct EvalAuc : public Metric { // this is the AUC sum_auc += sum_pospair / (sum_npos * sum_nneg); } - CHECK(!auc_error) - << "AUC: the dataset only contains pos or neg samples"; - /* Report average AUC across all groups */ - if (distributed) { - bst_float dat[2]; + + // Report average AUC across all groups + // In distributed mode, workers which only contains pos or neg samples + // will be ignored when aggregate AUC. + bst_float dat[2] = {0.0f, 0.0f}; + if (!auc_error) { dat[0] = static_cast(sum_auc); dat[1] = static_cast(ngroup); - rabit::Allreduce(dat, 2); - return dat[0] / dat[1]; - } else { - return static_cast(sum_auc) / ngroup; } + if (distributed) { + rabit::Allreduce(dat, 2); + } + CHECK_GT(dat[1], 0.0f) + << "AUC: the dataset only contains pos or neg samples"; + return dat[0] / dat[1]; } public: @@ -515,19 +518,22 @@ struct EvalAucPR : public Metric { CHECK(!auc_error) << "AUC-PR: error in calculation"; } } - CHECK(!auc_error) << "AUC-PR: the dataset only contains pos or neg samples"; - /* Report average AUC across all groups */ - if (distributed) { - bst_float dat[2]; + + // Report average AUC-PR across all groups + // In distributed mode, workers which only contains pos or neg samples + // will be ignored when aggregate AUC-PR. + bst_float dat[2] = {0.0f, 0.0f}; + if (!auc_error) { dat[0] = static_cast(sum_auc); dat[1] = static_cast(ngroup); - rabit::Allreduce(dat, 2); - CHECK_LE(dat[0], dat[1]) << "AUC-PR: AUC > 1.0"; - return dat[0] / dat[1]; - } else { - CHECK_LE(sum_auc, static_cast(ngroup)) << "AUC-PR: AUC > 1.0"; - return static_cast(sum_auc) / ngroup; } + if (distributed) { + rabit::Allreduce(dat, 2); + } + CHECK_GT(dat[1], 0.0f) + << "AUC-PR: the dataset only contains pos or neg samples"; + CHECK_LE(dat[0], dat[1]) << "AUC-PR: AUC > 1.0"; + return dat[0] / dat[1]; } public: