fix auc error in distributed mode caused by unbalanced dataset (#4645)
This commit is contained in:
parent
30204b50fe
commit
cd1526d3b1
@ -197,18 +197,21 @@ struct EvalAuc : public Metric {
|
|||||||
// this is the AUC
|
// this is the AUC
|
||||||
sum_auc += sum_pospair / (sum_npos * sum_nneg);
|
sum_auc += sum_pospair / (sum_npos * sum_nneg);
|
||||||
}
|
}
|
||||||
CHECK(!auc_error)
|
|
||||||
<< "AUC: the dataset only contains pos or neg samples";
|
// Report average AUC across all groups
|
||||||
/* Report average AUC across all groups */
|
// In distributed mode, workers which only contains pos or neg samples
|
||||||
if (distributed) {
|
// will be ignored when aggregate AUC.
|
||||||
bst_float dat[2];
|
bst_float dat[2] = {0.0f, 0.0f};
|
||||||
|
if (!auc_error) {
|
||||||
dat[0] = static_cast<bst_float>(sum_auc);
|
dat[0] = static_cast<bst_float>(sum_auc);
|
||||||
dat[1] = static_cast<bst_float>(ngroup);
|
dat[1] = static_cast<bst_float>(ngroup);
|
||||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
|
||||||
return dat[0] / dat[1];
|
|
||||||
} else {
|
|
||||||
return static_cast<bst_float>(sum_auc) / ngroup;
|
|
||||||
}
|
}
|
||||||
|
if (distributed) {
|
||||||
|
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||||
|
}
|
||||||
|
CHECK_GT(dat[1], 0.0f)
|
||||||
|
<< "AUC: the dataset only contains pos or neg samples";
|
||||||
|
return dat[0] / dat[1];
|
||||||
}
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
@ -515,19 +518,22 @@ struct EvalAucPR : public Metric {
|
|||||||
CHECK(!auc_error) << "AUC-PR: error in calculation";
|
CHECK(!auc_error) << "AUC-PR: error in calculation";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
CHECK(!auc_error) << "AUC-PR: the dataset only contains pos or neg samples";
|
|
||||||
/* Report average AUC across all groups */
|
// Report average AUC-PR across all groups
|
||||||
if (distributed) {
|
// In distributed mode, workers which only contains pos or neg samples
|
||||||
bst_float dat[2];
|
// will be ignored when aggregate AUC-PR.
|
||||||
|
bst_float dat[2] = {0.0f, 0.0f};
|
||||||
|
if (!auc_error) {
|
||||||
dat[0] = static_cast<bst_float>(sum_auc);
|
dat[0] = static_cast<bst_float>(sum_auc);
|
||||||
dat[1] = static_cast<bst_float>(ngroup);
|
dat[1] = static_cast<bst_float>(ngroup);
|
||||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
|
||||||
CHECK_LE(dat[0], dat[1]) << "AUC-PR: AUC > 1.0";
|
|
||||||
return dat[0] / dat[1];
|
|
||||||
} else {
|
|
||||||
CHECK_LE(sum_auc, static_cast<double>(ngroup)) << "AUC-PR: AUC > 1.0";
|
|
||||||
return static_cast<bst_float>(sum_auc) / ngroup;
|
|
||||||
}
|
}
|
||||||
|
if (distributed) {
|
||||||
|
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||||
|
}
|
||||||
|
CHECK_GT(dat[1], 0.0f)
|
||||||
|
<< "AUC-PR: the dataset only contains pos or neg samples";
|
||||||
|
CHECK_LE(dat[0], dat[1]) << "AUC-PR: AUC > 1.0";
|
||||||
|
return dat[0] / dat[1];
|
||||||
}
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user