Avoid OMP reduction in AUC. (#7362)
This commit is contained in:
parent
ac9bfaa4f2
commit
d05754f558
@ -7,6 +7,7 @@
|
|||||||
#include <functional>
|
#include <functional>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <numeric>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
@ -203,42 +204,39 @@ std::pair<float, uint32_t> RankingAUC(std::vector<float> const &predts,
|
|||||||
MetaInfo const &info, int32_t n_threads) {
|
MetaInfo const &info, int32_t n_threads) {
|
||||||
CHECK_GE(info.group_ptr_.size(), 2);
|
CHECK_GE(info.group_ptr_.size(), 2);
|
||||||
uint32_t n_groups = info.group_ptr_.size() - 1;
|
uint32_t n_groups = info.group_ptr_.size() - 1;
|
||||||
float sum_auc = 0;
|
|
||||||
auto s_predts = common::Span<float const>{predts};
|
auto s_predts = common::Span<float const>{predts};
|
||||||
auto s_labels = info.labels_.ConstHostSpan();
|
auto s_labels = info.labels_.ConstHostSpan();
|
||||||
auto s_weights = info.weights_.ConstHostSpan();
|
auto s_weights = info.weights_.ConstHostSpan();
|
||||||
|
|
||||||
std::atomic<uint32_t> invalid_groups{0};
|
std::atomic<uint32_t> invalid_groups{0};
|
||||||
dmlc::OMPException omp_handler;
|
|
||||||
|
|
||||||
#pragma omp parallel for reduction(+:sum_auc) num_threads(n_threads)
|
std::vector<double> auc_tloc(n_threads, 0);
|
||||||
for (omp_ulong g = 1; g < info.group_ptr_.size(); ++g) {
|
common::ParallelFor(n_groups, n_threads, [&](size_t g) {
|
||||||
omp_handler.Run([&]() {
|
g += 1; // indexing needs to start from 1
|
||||||
size_t cnt = info.group_ptr_[g] - info.group_ptr_[g - 1];
|
size_t cnt = info.group_ptr_[g] - info.group_ptr_[g - 1];
|
||||||
float w = s_weights.empty() ? 1.0f : s_weights[g - 1];
|
float w = s_weights.empty() ? 1.0f : s_weights[g - 1];
|
||||||
auto g_predts = s_predts.subspan(info.group_ptr_[g - 1], cnt);
|
auto g_predts = s_predts.subspan(info.group_ptr_[g - 1], cnt);
|
||||||
auto g_labels = s_labels.subspan(info.group_ptr_[g - 1], cnt);
|
auto g_labels = s_labels.subspan(info.group_ptr_[g - 1], cnt);
|
||||||
float auc;
|
float auc;
|
||||||
if (is_roc && g_labels.size() < 3) {
|
if (is_roc && g_labels.size() < 3) {
|
||||||
// With 2 documents, there's only 1 comparison can be made. So either
|
// With 2 documents, there's only 1 comparison can be made. So either
|
||||||
// TP or FP will be zero.
|
// TP or FP will be zero.
|
||||||
|
invalid_groups++;
|
||||||
|
auc = 0;
|
||||||
|
} else {
|
||||||
|
if (is_roc) {
|
||||||
|
auc = GroupRankingROC(g_predts, g_labels, w);
|
||||||
|
} else {
|
||||||
|
auc = std::get<2>(BinaryPRAUC(g_predts, g_labels, OptionalWeights{w}));
|
||||||
|
}
|
||||||
|
if (std::isnan(auc)) {
|
||||||
invalid_groups++;
|
invalid_groups++;
|
||||||
auc = 0;
|
auc = 0;
|
||||||
} else {
|
|
||||||
if (is_roc) {
|
|
||||||
auc = GroupRankingROC(g_predts, g_labels, w);
|
|
||||||
} else {
|
|
||||||
auc = std::get<2>(BinaryPRAUC(g_predts, g_labels, OptionalWeights{w}));
|
|
||||||
}
|
|
||||||
if (std::isnan(auc)) {
|
|
||||||
invalid_groups++;
|
|
||||||
auc = 0;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
sum_auc += auc;
|
}
|
||||||
});
|
auc_tloc[omp_get_thread_num()] += auc;
|
||||||
}
|
});
|
||||||
omp_handler.Rethrow();
|
float sum_auc = std::accumulate(auc_tloc.cbegin(), auc_tloc.cend(), 0.0);
|
||||||
|
|
||||||
return std::make_pair(sum_auc, n_groups - invalid_groups);
|
return std::make_pair(sum_auc, n_groups - invalid_groups);
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user