fix auc error in distributed mode caused by unbalanced dataset (#4645)
This commit is contained in:
parent
30204b50fe
commit
cd1526d3b1
@ -197,18 +197,21 @@ struct EvalAuc : public Metric {
|
||||
// this is the AUC
|
||||
sum_auc += sum_pospair / (sum_npos * sum_nneg);
|
||||
}
|
||||
CHECK(!auc_error)
|
||||
<< "AUC: the dataset only contains pos or neg samples";
|
||||
/* Report average AUC across all groups */
|
||||
if (distributed) {
|
||||
bst_float dat[2];
|
||||
|
||||
// Report average AUC across all groups
|
||||
// In distributed mode, workers which only contains pos or neg samples
|
||||
// will be ignored when aggregate AUC.
|
||||
bst_float dat[2] = {0.0f, 0.0f};
|
||||
if (!auc_error) {
|
||||
dat[0] = static_cast<bst_float>(sum_auc);
|
||||
dat[1] = static_cast<bst_float>(ngroup);
|
||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||
return dat[0] / dat[1];
|
||||
} else {
|
||||
return static_cast<bst_float>(sum_auc) / ngroup;
|
||||
}
|
||||
if (distributed) {
|
||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||
}
|
||||
CHECK_GT(dat[1], 0.0f)
|
||||
<< "AUC: the dataset only contains pos or neg samples";
|
||||
return dat[0] / dat[1];
|
||||
}
|
||||
|
||||
public:
|
||||
@ -515,19 +518,22 @@ struct EvalAucPR : public Metric {
|
||||
CHECK(!auc_error) << "AUC-PR: error in calculation";
|
||||
}
|
||||
}
|
||||
CHECK(!auc_error) << "AUC-PR: the dataset only contains pos or neg samples";
|
||||
/* Report average AUC across all groups */
|
||||
if (distributed) {
|
||||
bst_float dat[2];
|
||||
|
||||
// Report average AUC-PR across all groups
|
||||
// In distributed mode, workers which only contains pos or neg samples
|
||||
// will be ignored when aggregate AUC-PR.
|
||||
bst_float dat[2] = {0.0f, 0.0f};
|
||||
if (!auc_error) {
|
||||
dat[0] = static_cast<bst_float>(sum_auc);
|
||||
dat[1] = static_cast<bst_float>(ngroup);
|
||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||
CHECK_LE(dat[0], dat[1]) << "AUC-PR: AUC > 1.0";
|
||||
return dat[0] / dat[1];
|
||||
} else {
|
||||
CHECK_LE(sum_auc, static_cast<double>(ngroup)) << "AUC-PR: AUC > 1.0";
|
||||
return static_cast<bst_float>(sum_auc) / ngroup;
|
||||
}
|
||||
if (distributed) {
|
||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||
}
|
||||
CHECK_GT(dat[1], 0.0f)
|
||||
<< "AUC-PR: the dataset only contains pos or neg samples";
|
||||
CHECK_LE(dat[0], dat[1]) << "AUC-PR: AUC > 1.0";
|
||||
return dat[0] / dat[1];
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user