fix auc error in distributed mode caused by unbalanced dataset (#4645)

This commit is contained in:
Xu Xiao 2019-07-08 16:01:52 +08:00 committed by Jiaming Yuan
parent 30204b50fe
commit cd1526d3b1

View File

@ -197,18 +197,21 @@ struct EvalAuc : public Metric {
// this is the AUC // this is the AUC
sum_auc += sum_pospair / (sum_npos * sum_nneg); sum_auc += sum_pospair / (sum_npos * sum_nneg);
} }
CHECK(!auc_error)
<< "AUC: the dataset only contains pos or neg samples"; // Report average AUC across all groups
/* Report average AUC across all groups */ // In distributed mode, workers which only contains pos or neg samples
if (distributed) { // will be ignored when aggregate AUC.
bst_float dat[2]; bst_float dat[2] = {0.0f, 0.0f};
if (!auc_error) {
dat[0] = static_cast<bst_float>(sum_auc); dat[0] = static_cast<bst_float>(sum_auc);
dat[1] = static_cast<bst_float>(ngroup); dat[1] = static_cast<bst_float>(ngroup);
rabit::Allreduce<rabit::op::Sum>(dat, 2);
return dat[0] / dat[1];
} else {
return static_cast<bst_float>(sum_auc) / ngroup;
} }
if (distributed) {
rabit::Allreduce<rabit::op::Sum>(dat, 2);
}
CHECK_GT(dat[1], 0.0f)
<< "AUC: the dataset only contains pos or neg samples";
return dat[0] / dat[1];
} }
public: public:
@ -515,19 +518,22 @@ struct EvalAucPR : public Metric {
CHECK(!auc_error) << "AUC-PR: error in calculation"; CHECK(!auc_error) << "AUC-PR: error in calculation";
} }
} }
CHECK(!auc_error) << "AUC-PR: the dataset only contains pos or neg samples";
/* Report average AUC across all groups */ // Report average AUC-PR across all groups
if (distributed) { // In distributed mode, workers which only contains pos or neg samples
bst_float dat[2]; // will be ignored when aggregate AUC-PR.
bst_float dat[2] = {0.0f, 0.0f};
if (!auc_error) {
dat[0] = static_cast<bst_float>(sum_auc); dat[0] = static_cast<bst_float>(sum_auc);
dat[1] = static_cast<bst_float>(ngroup); dat[1] = static_cast<bst_float>(ngroup);
rabit::Allreduce<rabit::op::Sum>(dat, 2);
CHECK_LE(dat[0], dat[1]) << "AUC-PR: AUC > 1.0";
return dat[0] / dat[1];
} else {
CHECK_LE(sum_auc, static_cast<double>(ngroup)) << "AUC-PR: AUC > 1.0";
return static_cast<bst_float>(sum_auc) / ngroup;
} }
if (distributed) {
rabit::Allreduce<rabit::op::Sum>(dat, 2);
}
CHECK_GT(dat[1], 0.0f)
<< "AUC-PR: the dataset only contains pos or neg samples";
CHECK_LE(dat[0], dat[1]) << "AUC-PR: AUC > 1.0";
return dat[0] / dat[1];
} }
public: public: