Fix auc error in distributed mode (#4798)

More work is needed for a complete fix; see #4663.
This commit is contained in:
TinkleG 2019-09-01 14:54:14 +08:00 committed by Jiaming Yuan
parent 733ed24dd9
commit 2aed0ae230

View File

@ -191,7 +191,7 @@ struct EvalAuc : public Metric {
sum_nneg += buf_neg; sum_nneg += buf_neg;
// check weird conditions // check weird conditions
if (sum_npos <= 0.0 || sum_nneg <= 0.0) { if (sum_npos <= 0.0 || sum_nneg <= 0.0) {
auc_error = 1; auc_error += 1;
continue; continue;
} }
// this is the AUC // this is the AUC
@ -202,9 +202,9 @@ struct EvalAuc : public Metric {
// In distributed mode, workers which only contains pos or neg samples // In distributed mode, workers which only contains pos or neg samples
// will be ignored when aggregate AUC. // will be ignored when aggregate AUC.
bst_float dat[2] = {0.0f, 0.0f}; bst_float dat[2] = {0.0f, 0.0f};
if (!auc_error) { if (auc_error < static_cast<int>(ngroup)) {
dat[0] = static_cast<bst_float>(sum_auc); dat[0] = static_cast<bst_float>(sum_auc);
dat[1] = static_cast<bst_float>(ngroup); dat[1] = static_cast<bst_float>(static_cast<int>(ngroup) - auc_error);
} }
if (distributed) { if (distributed) {
rabit::Allreduce<rabit::op::Sum>(dat, 2); rabit::Allreduce<rabit::op::Sum>(dat, 2);
@ -484,7 +484,8 @@ struct EvalAucPR : public Metric {
XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst); XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
// we need pos > 0 && neg > 0 // we need pos > 0 && neg > 0
if (0.0 == total_pos || 0.0 == total_neg) { if (0.0 == total_pos || 0.0 == total_neg) {
auc_error = 1; auc_error += 1;
continue;
} }
// calculate AUC // calculate AUC
double tp = 0.0, prevtp = 0.0, fp = 0.0, prevfp = 0.0, h = 0.0, a = 0.0, b = 0.0; double tp = 0.0, prevtp = 0.0, fp = 0.0, prevfp = 0.0, h = 0.0, a = 0.0, b = 0.0;
@ -523,9 +524,9 @@ struct EvalAucPR : public Metric {
// In distributed mode, workers which only contains pos or neg samples // In distributed mode, workers which only contains pos or neg samples
// will be ignored when aggregate AUC-PR. // will be ignored when aggregate AUC-PR.
bst_float dat[2] = {0.0f, 0.0f}; bst_float dat[2] = {0.0f, 0.0f};
if (!auc_error) { if (auc_error < static_cast<int>(ngroup)) {
dat[0] = static_cast<bst_float>(sum_auc); dat[0] = static_cast<bst_float>(sum_auc);
dat[1] = static_cast<bst_float>(ngroup); dat[1] = static_cast<bst_float>(static_cast<int>(ngroup) - auc_error);
} }
if (distributed) { if (distributed) {
rabit::Allreduce<rabit::op::Sum>(dat, 2); rabit::Allreduce<rabit::op::Sum>(dat, 2);