fix auc error in distributed mode caused by unbalanced dataset (#4645)

This commit is contained in:
Xu Xiao 2019-07-08 16:01:52 +08:00 committed by Jiaming Yuan
parent 30204b50fe
commit cd1526d3b1

View File

@ -197,18 +197,21 @@ struct EvalAuc : public Metric {
// this is the AUC
sum_auc += sum_pospair / (sum_npos * sum_nneg);
}
CHECK(!auc_error)
<< "AUC: the dataset only contains pos or neg samples";
/* Report average AUC across all groups */
if (distributed) {
bst_float dat[2];
// Report average AUC across all groups
// In distributed mode, workers which only contains pos or neg samples
// will be ignored when aggregate AUC.
bst_float dat[2] = {0.0f, 0.0f};
if (!auc_error) {
dat[0] = static_cast<bst_float>(sum_auc);
dat[1] = static_cast<bst_float>(ngroup);
rabit::Allreduce<rabit::op::Sum>(dat, 2);
return dat[0] / dat[1];
} else {
return static_cast<bst_float>(sum_auc) / ngroup;
}
if (distributed) {
rabit::Allreduce<rabit::op::Sum>(dat, 2);
}
CHECK_GT(dat[1], 0.0f)
<< "AUC: the dataset only contains pos or neg samples";
return dat[0] / dat[1];
}
public:
@ -515,19 +518,22 @@ struct EvalAucPR : public Metric {
CHECK(!auc_error) << "AUC-PR: error in calculation";
}
}
CHECK(!auc_error) << "AUC-PR: the dataset only contains pos or neg samples";
/* Report average AUC across all groups */
if (distributed) {
bst_float dat[2];
// Report average AUC-PR across all groups
// In distributed mode, workers which only contains pos or neg samples
// will be ignored when aggregate AUC-PR.
bst_float dat[2] = {0.0f, 0.0f};
if (!auc_error) {
dat[0] = static_cast<bst_float>(sum_auc);
dat[1] = static_cast<bst_float>(ngroup);
rabit::Allreduce<rabit::op::Sum>(dat, 2);
CHECK_LE(dat[0], dat[1]) << "AUC-PR: AUC > 1.0";
return dat[0] / dat[1];
} else {
CHECK_LE(sum_auc, static_cast<double>(ngroup)) << "AUC-PR: AUC > 1.0";
return static_cast<bst_float>(sum_auc) / ngroup;
}
if (distributed) {
rabit::Allreduce<rabit::op::Sum>(dat, 2);
}
CHECK_GT(dat[1], 0.0f)
<< "AUC-PR: the dataset only contains pos or neg samples";
CHECK_LE(dat[0], dat[1]) << "AUC-PR: AUC > 1.0";
return dat[0] / dat[1];
}
public: