Be more lenient on floating point error for AUC. (#10264)

This commit is contained in:
Jiaming Yuan 2024-05-11 08:48:11 +08:00 committed by GitHub
parent f588252481
commit 5de57435c7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -177,7 +177,7 @@ double GroupRankingROC(Context const* ctx, common::Span<float const> predts,
if (sum_w != 0) { if (sum_w != 0) {
auc /= sum_w; auc /= sum_w;
} }
CHECK_LE(auc, 1.0f); CHECK_LE(auc, 1.0 + kRtEps);
return auc; return auc;
} }
@ -290,8 +290,8 @@ class EvalAUC : public MetricNoCache {
auc = collective::GlobalRatio(ctx_, info, auc, static_cast<double>(valid_groups)); auc = collective::GlobalRatio(ctx_, info, auc, static_cast<double>(valid_groups));
if (!std::isnan(auc)) { if (!std::isnan(auc)) {
CHECK_LE(auc, 1) << "Total AUC across groups: " << auc * valid_groups CHECK_LE(auc, 1.0 + kRtEps) << "Total AUC across groups: " << auc * valid_groups
<< ", valid groups: " << valid_groups; << ", valid groups: " << valid_groups;
} }
} else if (meta[0] != meta[1] && meta[1] % meta[0] == 0) { } else if (meta[0] != meta[1] && meta[1] % meta[0] == 0) {
/** /**
@ -311,7 +311,8 @@ class EvalAUC : public MetricNoCache {
} }
auc = collective::GlobalRatio(ctx_, info, auc, fp * tp); auc = collective::GlobalRatio(ctx_, info, auc, fp * tp);
if (!std::isnan(auc)) { if (!std::isnan(auc)) {
CHECK_LE(auc, 1.0); CHECK_LE(auc, 1.0 + kRtEps);
auc = std::min(auc, 1.0);
} }
} }
if (std::isnan(auc)) { if (std::isnan(auc)) {