Fix auc error in distributed mode (#4798)
Need more work for a complete fix. See #4663 .
This commit is contained in:
parent
733ed24dd9
commit
2aed0ae230
@ -191,7 +191,7 @@ struct EvalAuc : public Metric {
|
||||
sum_nneg += buf_neg;
|
||||
// check weird conditions
|
||||
if (sum_npos <= 0.0 || sum_nneg <= 0.0) {
|
||||
auc_error = 1;
|
||||
auc_error += 1;
|
||||
continue;
|
||||
}
|
||||
// this is the AUC
|
||||
@ -202,9 +202,9 @@ struct EvalAuc : public Metric {
|
||||
// In distributed mode, workers which only contains pos or neg samples
|
||||
// will be ignored when aggregate AUC.
|
||||
bst_float dat[2] = {0.0f, 0.0f};
|
||||
if (!auc_error) {
|
||||
if (auc_error < static_cast<int>(ngroup)) {
|
||||
dat[0] = static_cast<bst_float>(sum_auc);
|
||||
dat[1] = static_cast<bst_float>(ngroup);
|
||||
dat[1] = static_cast<bst_float>(static_cast<int>(ngroup) - auc_error);
|
||||
}
|
||||
if (distributed) {
|
||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||
@ -484,7 +484,8 @@ struct EvalAucPR : public Metric {
|
||||
XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
|
||||
// we need pos > 0 && neg > 0
|
||||
if (0.0 == total_pos || 0.0 == total_neg) {
|
||||
auc_error = 1;
|
||||
auc_error += 1;
|
||||
continue;
|
||||
}
|
||||
// calculate AUC
|
||||
double tp = 0.0, prevtp = 0.0, fp = 0.0, prevfp = 0.0, h = 0.0, a = 0.0, b = 0.0;
|
||||
@ -523,9 +524,9 @@ struct EvalAucPR : public Metric {
|
||||
// In distributed mode, workers which only contains pos or neg samples
|
||||
// will be ignored when aggregate AUC-PR.
|
||||
bst_float dat[2] = {0.0f, 0.0f};
|
||||
if (!auc_error) {
|
||||
if (auc_error < static_cast<int>(ngroup)) {
|
||||
dat[0] = static_cast<bst_float>(sum_auc);
|
||||
dat[1] = static_cast<bst_float>(ngroup);
|
||||
dat[1] = static_cast<bst_float>(static_cast<int>(ngroup) - auc_error);
|
||||
}
|
||||
if (distributed) {
|
||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user