Avoid OMP reduction in AUC. (#7362)

This commit is contained in:
Jiaming Yuan 2021-10-28 05:03:52 +08:00 committed by GitHub
parent ac9bfaa4f2
commit d05754f558
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -7,6 +7,7 @@
#include <functional> #include <functional>
#include <limits> #include <limits>
#include <memory> #include <memory>
#include <numeric>
#include <utility> #include <utility>
#include <tuple> #include <tuple>
#include <vector> #include <vector>
@ -203,42 +204,39 @@ std::pair<float, uint32_t> RankingAUC(std::vector<float> const &predts,
MetaInfo const &info, int32_t n_threads) { MetaInfo const &info, int32_t n_threads) {
CHECK_GE(info.group_ptr_.size(), 2); CHECK_GE(info.group_ptr_.size(), 2);
uint32_t n_groups = info.group_ptr_.size() - 1; uint32_t n_groups = info.group_ptr_.size() - 1;
float sum_auc = 0;
auto s_predts = common::Span<float const>{predts}; auto s_predts = common::Span<float const>{predts};
auto s_labels = info.labels_.ConstHostSpan(); auto s_labels = info.labels_.ConstHostSpan();
auto s_weights = info.weights_.ConstHostSpan(); auto s_weights = info.weights_.ConstHostSpan();
std::atomic<uint32_t> invalid_groups{0}; std::atomic<uint32_t> invalid_groups{0};
dmlc::OMPException omp_handler;
#pragma omp parallel for reduction(+:sum_auc) num_threads(n_threads) std::vector<double> auc_tloc(n_threads, 0);
for (omp_ulong g = 1; g < info.group_ptr_.size(); ++g) { common::ParallelFor(n_groups, n_threads, [&](size_t g) {
omp_handler.Run([&]() { g += 1; // indexing needs to start from 1
size_t cnt = info.group_ptr_[g] - info.group_ptr_[g - 1]; size_t cnt = info.group_ptr_[g] - info.group_ptr_[g - 1];
float w = s_weights.empty() ? 1.0f : s_weights[g - 1]; float w = s_weights.empty() ? 1.0f : s_weights[g - 1];
auto g_predts = s_predts.subspan(info.group_ptr_[g - 1], cnt); auto g_predts = s_predts.subspan(info.group_ptr_[g - 1], cnt);
auto g_labels = s_labels.subspan(info.group_ptr_[g - 1], cnt); auto g_labels = s_labels.subspan(info.group_ptr_[g - 1], cnt);
float auc; float auc;
if (is_roc && g_labels.size() < 3) { if (is_roc && g_labels.size() < 3) {
// With 2 documents, only 1 comparison can be made. So either // With 2 documents, only 1 comparison can be made. So either
// TP or FP will be zero. // TP or FP will be zero.
invalid_groups++;
auc = 0;
} else {
if (is_roc) {
auc = GroupRankingROC(g_predts, g_labels, w);
} else {
auc = std::get<2>(BinaryPRAUC(g_predts, g_labels, OptionalWeights{w}));
}
if (std::isnan(auc)) {
invalid_groups++; invalid_groups++;
auc = 0; auc = 0;
} else {
if (is_roc) {
auc = GroupRankingROC(g_predts, g_labels, w);
} else {
auc = std::get<2>(BinaryPRAUC(g_predts, g_labels, OptionalWeights{w}));
}
if (std::isnan(auc)) {
invalid_groups++;
auc = 0;
}
} }
sum_auc += auc; }
}); auc_tloc[omp_get_thread_num()] += auc;
} });
omp_handler.Rethrow(); float sum_auc = std::accumulate(auc_tloc.cbegin(), auc_tloc.cend(), 0.0);
return std::make_pair(sum_auc, n_groups - invalid_groups); return std::make_pair(sum_auc, n_groups - invalid_groups);
} }