From d05754f558a50d68322e36f14d1c4f11d2f360d9 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 28 Oct 2021 05:03:52 +0800 Subject: [PATCH] Avoid OMP reduction in AUC. (#7362) --- src/metric/auc.cc | 52 +++++++++++++++++++++++------------------------ 1 file changed, 25 insertions(+), 27 deletions(-) diff --git a/src/metric/auc.cc b/src/metric/auc.cc index 63315150f..d58aefca3 100644 --- a/src/metric/auc.cc +++ b/src/metric/auc.cc @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -203,42 +204,39 @@ std::pair RankingAUC(std::vector const &predts, MetaInfo const &info, int32_t n_threads) { CHECK_GE(info.group_ptr_.size(), 2); uint32_t n_groups = info.group_ptr_.size() - 1; - float sum_auc = 0; auto s_predts = common::Span{predts}; auto s_labels = info.labels_.ConstHostSpan(); auto s_weights = info.weights_.ConstHostSpan(); std::atomic invalid_groups{0}; - dmlc::OMPException omp_handler; -#pragma omp parallel for reduction(+:sum_auc) num_threads(n_threads) - for (omp_ulong g = 1; g < info.group_ptr_.size(); ++g) { - omp_handler.Run([&]() { - size_t cnt = info.group_ptr_[g] - info.group_ptr_[g - 1]; - float w = s_weights.empty() ? 1.0f : s_weights[g - 1]; - auto g_predts = s_predts.subspan(info.group_ptr_[g - 1], cnt); - auto g_labels = s_labels.subspan(info.group_ptr_[g - 1], cnt); - float auc; - if (is_roc && g_labels.size() < 3) { - // With 2 documents, there's only 1 comparison can be made. So either - // TP or FP will be zero. + std::vector auc_tloc(n_threads, 0); + common::ParallelFor(n_groups, n_threads, [&](size_t g) { + g += 1; // indexing needs to start from 1 + size_t cnt = info.group_ptr_[g] - info.group_ptr_[g - 1]; + float w = s_weights.empty() ? 1.0f : s_weights[g - 1]; + auto g_predts = s_predts.subspan(info.group_ptr_[g - 1], cnt); + auto g_labels = s_labels.subspan(info.group_ptr_[g - 1], cnt); + float auc; + if (is_roc && g_labels.size() < 3) { + // With 2 documents, there's only 1 comparison can be made. So either + // TP or FP will be zero. + invalid_groups++; + auc = 0; + } else { + if (is_roc) { + auc = GroupRankingROC(g_predts, g_labels, w); + } else { + auc = std::get<2>(BinaryPRAUC(g_predts, g_labels, OptionalWeights{w})); + } + if (std::isnan(auc)) { invalid_groups++; auc = 0; - } else { - if (is_roc) { - auc = GroupRankingROC(g_predts, g_labels, w); - } else { - auc = std::get<2>(BinaryPRAUC(g_predts, g_labels, OptionalWeights{w})); - } - if (std::isnan(auc)) { - invalid_groups++; - auc = 0; - } } - sum_auc += auc; - }); - } - omp_handler.Rethrow(); + } + auc_tloc[omp_get_thread_num()] += auc; + }); + float sum_auc = std::accumulate(auc_tloc.cbegin(), auc_tloc.cend(), 0.0); return std::make_pair(sum_auc, n_groups - invalid_groups); }