Specify the number of threads for parallel sort. (#8735)

* Specify the number of threads for parallel sort.

- Pass context object into argsort.
- Replace macros with inline functions.
This commit is contained in:
Jiaming Yuan
2023-02-16 00:20:19 +08:00
committed by GitHub
parent c7c485d052
commit 282b1729da
24 changed files with 254 additions and 143 deletions

View File

@@ -14,9 +14,11 @@
#include <utility>
#include <vector>
#include "../common/algorithm.h" // ArgSort
#include "../common/math.h"
#include "../common/optional_weight.h" // OptionalWeights
#include "metric_common.h" // MetricNoCache
#include "xgboost/context.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/linalg.h"
#include "xgboost/metric.h"
@@ -77,9 +79,8 @@ BinaryAUC(common::Span<float const> predts, linalg::VectorView<float const> labe
* Machine Learning Models
*/
template <typename BinaryAUC>
double MultiClassOVR(common::Span<float const> predts, MetaInfo const &info,
size_t n_classes, int32_t n_threads,
BinaryAUC &&binary_auc) {
double MultiClassOVR(Context const *ctx, common::Span<float const> predts, MetaInfo const &info,
size_t n_classes, int32_t n_threads, BinaryAUC &&binary_auc) {
CHECK_NE(n_classes, 0);
auto const labels = info.labels.View(Context::kCpuId);
if (labels.Shape(0) != 0) {
@@ -108,7 +109,7 @@ double MultiClassOVR(common::Span<float const> predts, MetaInfo const &info,
}
double fp;
std::tie(fp, tp(c), auc(c)) =
binary_auc(proba, linalg::MakeVec(response.data(), response.size(), -1), weights);
binary_auc(ctx, proba, linalg::MakeVec(response.data(), response.size(), -1), weights);
local_area(c) = fp * tp(c);
});
}
@@ -139,23 +140,26 @@ double MultiClassOVR(common::Span<float const> predts, MetaInfo const &info,
return auc_sum;
}
std::tuple<double, double, double> BinaryROCAUC(common::Span<float const> predts,
std::tuple<double, double, double> BinaryROCAUC(Context const *ctx,
common::Span<float const> predts,
linalg::VectorView<float const> labels,
common::OptionalWeights weights) {
auto const sorted_idx = common::ArgSort<size_t>(predts, std::greater<>{});
auto const sorted_idx =
common::ArgSort<size_t>(ctx, predts.data(), predts.data() + predts.size(), std::greater<>{});
return BinaryAUC(predts, labels, weights, sorted_idx, TrapezoidArea);
}
/**
* Calculate AUC for 1 ranking group;
*/
double GroupRankingROC(common::Span<float const> predts,
double GroupRankingROC(Context const* ctx, common::Span<float const> predts,
linalg::VectorView<float const> labels, float w) {
// on ranking, we just count all pairs.
double auc{0};
// argsort doesn't support tensor input yet.
auto raw_labels = labels.Values().subspan(0, labels.Size());
auto const sorted_idx = common::ArgSort<size_t>(raw_labels, std::greater<>{});
auto const sorted_idx = common::ArgSort<size_t>(
ctx, raw_labels.data(), raw_labels.data() + raw_labels.size(), std::greater<>{});
w = common::Sqr(w);
double sum_w = 0.0f;
@@ -185,10 +189,11 @@ double GroupRankingROC(common::Span<float const> predts,
*
* https://doi.org/10.1371/journal.pone.0092209
*/
std::tuple<double, double, double> BinaryPRAUC(common::Span<float const> predts,
std::tuple<double, double, double> BinaryPRAUC(Context const *ctx, common::Span<float const> predts,
linalg::VectorView<float const> labels,
common::OptionalWeights weights) {
auto const sorted_idx = common::ArgSort<size_t>(predts, std::greater<>{});
auto const sorted_idx =
common::ArgSort<size_t>(ctx, predts.data(), predts.data() + predts.size(), std::greater<>{});
double total_pos{0}, total_neg{0};
for (size_t i = 0; i < labels.Size(); ++i) {
auto w = weights[i];
@@ -211,9 +216,8 @@ std::tuple<double, double, double> BinaryPRAUC(common::Span<float const> predts,
* Cast LTR problem to binary classification problem by comparing pairs.
*/
template <bool is_roc>
std::pair<double, uint32_t> RankingAUC(std::vector<float> const &predts,
MetaInfo const &info,
int32_t n_threads) {
std::pair<double, uint32_t> RankingAUC(Context const *ctx, std::vector<float> const &predts,
MetaInfo const &info, int32_t n_threads) {
CHECK_GE(info.group_ptr_.size(), 2);
uint32_t n_groups = info.group_ptr_.size() - 1;
auto s_predts = common::Span<float const>{predts};
@@ -237,9 +241,9 @@ std::pair<double, uint32_t> RankingAUC(std::vector<float> const &predts,
auc = 0;
} else {
if (is_roc) {
auc = GroupRankingROC(g_predts, g_labels, w);
auc = GroupRankingROC(ctx, g_predts, g_labels, w);
} else {
auc = std::get<2>(BinaryPRAUC(g_predts, g_labels, common::OptionalWeights{w}));
auc = std::get<2>(BinaryPRAUC(ctx, g_predts, g_labels, common::OptionalWeights{w}));
}
if (std::isnan(auc)) {
invalid_groups++;
@@ -344,7 +348,7 @@ class EvalROCAUC : public EvalAUC<EvalROCAUC> {
auto n_threads = ctx_->Threads();
if (ctx_->gpu_id == Context::kCpuId) {
std::tie(auc, valid_groups) =
RankingAUC<true>(predts.ConstHostVector(), info, n_threads);
RankingAUC<true>(ctx_, predts.ConstHostVector(), info, n_threads);
} else {
std::tie(auc, valid_groups) =
GPURankingAUC(ctx_, predts.ConstDeviceSpan(), info, &this->d_cache_);
@@ -358,8 +362,7 @@ class EvalROCAUC : public EvalAUC<EvalROCAUC> {
auto n_threads = ctx_->Threads();
CHECK_NE(n_classes, 0);
if (ctx_->gpu_id == Context::kCpuId) {
auc = MultiClassOVR(predts.ConstHostVector(), info, n_classes, n_threads,
BinaryROCAUC);
auc = MultiClassOVR(ctx_, predts.ConstHostVector(), info, n_classes, n_threads, BinaryROCAUC);
} else {
auc = GPUMultiClassROCAUC(ctx_, predts.ConstDeviceSpan(), info, &this->d_cache_, n_classes);
}
@@ -370,9 +373,9 @@ class EvalROCAUC : public EvalAUC<EvalROCAUC> {
EvalBinary(HostDeviceVector<float> const &predts, MetaInfo const &info) {
double fp, tp, auc;
if (ctx_->gpu_id == Context::kCpuId) {
std::tie(fp, tp, auc) =
BinaryROCAUC(predts.ConstHostVector(), info.labels.HostView().Slice(linalg::All(), 0),
common::OptionalWeights{info.weights_.ConstHostSpan()});
std::tie(fp, tp, auc) = BinaryROCAUC(ctx_, predts.ConstHostVector(),
info.labels.HostView().Slice(linalg::All(), 0),
common::OptionalWeights{info.weights_.ConstHostSpan()});
} else {
std::tie(fp, tp, auc) = GPUBinaryROCAUC(predts.ConstDeviceSpan(), info,
ctx_->gpu_id, &this->d_cache_);
@@ -422,7 +425,7 @@ class EvalPRAUC : public EvalAUC<EvalPRAUC> {
double pr, re, auc;
if (ctx_->gpu_id == Context::kCpuId) {
std::tie(pr, re, auc) =
BinaryPRAUC(predts.ConstHostSpan(), info.labels.HostView().Slice(linalg::All(), 0),
BinaryPRAUC(ctx_, predts.ConstHostSpan(), info.labels.HostView().Slice(linalg::All(), 0),
common::OptionalWeights{info.weights_.ConstHostSpan()});
} else {
std::tie(pr, re, auc) = GPUBinaryPRAUC(predts.ConstDeviceSpan(), info,
@@ -435,8 +438,7 @@ class EvalPRAUC : public EvalAUC<EvalPRAUC> {
size_t n_classes) {
if (ctx_->gpu_id == Context::kCpuId) {
auto n_threads = this->ctx_->Threads();
return MultiClassOVR(predts.ConstHostSpan(), info, n_classes, n_threads,
BinaryPRAUC);
return MultiClassOVR(ctx_, predts.ConstHostSpan(), info, n_classes, n_threads, BinaryPRAUC);
} else {
return GPUMultiClassPRAUC(ctx_, predts.ConstDeviceSpan(), info, &d_cache_, n_classes);
}
@@ -453,7 +455,7 @@ class EvalPRAUC : public EvalAUC<EvalPRAUC> {
InvalidLabels();
}
std::tie(auc, valid_groups) =
RankingAUC<false>(predts.ConstHostVector(), info, n_threads);
RankingAUC<false>(ctx_, predts.ConstHostVector(), info, n_threads);
} else {
std::tie(auc, valid_groups) =
GPURankingPRAUC(ctx_, predts.ConstDeviceSpan(), info, &d_cache_);

View File

@@ -27,6 +27,7 @@
#include <vector>
#include "../collective/communicator-inl.h"
#include "../common/algorithm.h" // Sort
#include "../common/math.h"
#include "../common/ranking_utils.h" // MakeMetricName
#include "../common/threading_utils.h"
@@ -113,7 +114,7 @@ struct EvalAMS : public MetricNoCache {
const auto &h_preds = preds.ConstHostVector();
common::ParallelFor(ndata, ctx_->Threads(),
[&](bst_omp_uint i) { rec[i] = std::make_pair(h_preds[i], i); });
XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
common::Sort(ctx_, rec.begin(), rec.end(), common::CmpFirst);
auto ntop = static_cast<unsigned>(ratio_ * ndata);
if (ntop == 0) ntop = ndata;
const double br = 10.0;
@@ -330,7 +331,7 @@ struct EvalCox : public MetricNoCache {
using namespace std; // NOLINT(*)
const auto ndata = static_cast<bst_omp_uint>(info.labels.Size());
const auto &label_order = info.LabelAbsSort();
const auto &label_order = info.LabelAbsSort(ctx_);
// pre-compute a sum for the denominator
double exp_p_sum = 0; // we use double because we might need the precision with large datasets