diff --git a/src/metric/rank_metric.cc b/src/metric/rank_metric.cc index f29975e95..c8c7edd82 100644 --- a/src/metric/rank_metric.cc +++ b/src/metric/rank_metric.cc @@ -191,7 +191,7 @@ struct EvalAuc : public Metric { sum_nneg += buf_neg; // check weird conditions if (sum_npos <= 0.0 || sum_nneg <= 0.0) { - auc_error = 1; + auc_error += 1; continue; } // this is the AUC @@ -202,9 +202,9 @@ struct EvalAuc : public Metric { // In distributed mode, workers which only contains pos or neg samples // will be ignored when aggregate AUC. bst_float dat[2] = {0.0f, 0.0f}; - if (!auc_error) { + if (auc_error < static_cast(ngroup)) { dat[0] = static_cast(sum_auc); - dat[1] = static_cast(ngroup); + dat[1] = static_cast(static_cast(ngroup) - auc_error); } if (distributed) { rabit::Allreduce(dat, 2); @@ -484,7 +484,8 @@ struct EvalAucPR : public Metric { XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst); // we need pos > 0 && neg > 0 if (0.0 == total_pos || 0.0 == total_neg) { - auc_error = 1; + auc_error += 1; + continue; } // calculate AUC double tp = 0.0, prevtp = 0.0, fp = 0.0, prevfp = 0.0, h = 0.0, a = 0.0, b = 0.0; @@ -523,9 +524,9 @@ struct EvalAucPR : public Metric { // In distributed mode, workers which only contains pos or neg samples // will be ignored when aggregate AUC-PR. bst_float dat[2] = {0.0f, 0.0f}; - if (!auc_error) { + if (auc_error < static_cast(ngroup)) { dat[0] = static_cast(sum_auc); - dat[1] = static_cast(ngroup); + dat[1] = static_cast(static_cast(ngroup) - auc_error); } if (distributed) { rabit::Allreduce(dat, 2);