From 5de57435c7ae680029ef68565c4dbb4b3a761856 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 11 May 2024 08:48:11 +0800 Subject: [PATCH] Be more lenient on floating point error for AUC. (#10264) --- src/metric/auc.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/metric/auc.cc b/src/metric/auc.cc index 212a3a027..189c2b8e7 100644 --- a/src/metric/auc.cc +++ b/src/metric/auc.cc @@ -177,7 +177,7 @@ double GroupRankingROC(Context const* ctx, common::Span predts, if (sum_w != 0) { auc /= sum_w; } - CHECK_LE(auc, 1.0f); + CHECK_LE(auc, 1.0 + kRtEps); return auc; } @@ -290,8 +290,8 @@ class EvalAUC : public MetricNoCache { auc = collective::GlobalRatio(ctx_, info, auc, static_cast(valid_groups)); if (!std::isnan(auc)) { - CHECK_LE(auc, 1) << "Total AUC across groups: " << auc * valid_groups - << ", valid groups: " << valid_groups; + CHECK_LE(auc, 1.0 + kRtEps) << "Total AUC across groups: " << auc * valid_groups + << ", valid groups: " << valid_groups; } } else if (meta[0] != meta[1] && meta[1] % meta[0] == 0) { /** @@ -311,7 +311,8 @@ class EvalAUC : public MetricNoCache { } auc = collective::GlobalRatio(ctx_, info, auc, fp * tp); if (!std::isnan(auc)) { - CHECK_LE(auc, 1.0); + CHECK_LE(auc, 1.0 + kRtEps); + auc = std::min(auc, 1.0); } } if (std::isnan(auc)) {