[coll] Pass context to various functions. (#9772)

* [coll] Pass context to various functions.

In the future, the `Context` object will be required for collective operations; this PR
passes the context object to some of the required functions to prepare for swapping out
the implementation.
This commit is contained in:
Jiaming Yuan
2023-11-08 09:54:05 +08:00
committed by GitHub
parent 6c0a190f6d
commit 06bdc15e9b
45 changed files with 275 additions and 255 deletions

View File

@@ -199,9 +199,9 @@ void Transpose(common::Span<float const> in, common::Span<float> out, size_t m,
});
}
double ScaleClasses(common::Span<double> results, common::Span<double> local_area,
common::Span<double> tp, common::Span<double> auc, size_t n_classes) {
dh::XGBDeviceAllocator<char> alloc;
double ScaleClasses(Context const *ctx, common::Span<double> results,
common::Span<double> local_area, common::Span<double> tp,
common::Span<double> auc, size_t n_classes) {
if (collective::IsDistributed()) {
int32_t device = dh::CurrentDevice();
CHECK_EQ(dh::CudaGetPointerDevice(results.data()), device);
@@ -218,8 +218,8 @@ double ScaleClasses(common::Span<double> results, common::Span<double> local_are
double tp_sum;
double auc_sum;
thrust::tie(auc_sum, tp_sum) =
thrust::reduce(thrust::cuda::par(alloc), reduce_in, reduce_in + n_classes,
Pair{0.0, 0.0}, PairPlus<double, double>{});
thrust::reduce(ctx->CUDACtx()->CTP(), reduce_in, reduce_in + n_classes, Pair{0.0, 0.0},
PairPlus<double, double>{});
if (tp_sum != 0 && !std::isnan(auc_sum)) {
auc_sum /= tp_sum;
} else {
@@ -309,10 +309,10 @@ void SegmentedReduceAUC(common::Span<size_t const> d_unique_idx,
* up each class in all kernels.
*/
template <bool scale, typename Fn>
double GPUMultiClassAUCOVR(MetaInfo const &info, DeviceOrd device,
double GPUMultiClassAUCOVR(Context const *ctx, MetaInfo const &info,
common::Span<uint32_t> d_class_ptr, size_t n_classes,
std::shared_ptr<DeviceAUCCache> cache, Fn area_fn) {
dh::safe_cuda(cudaSetDevice(device.ordinal));
dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));
/**
* Sorted idx
*/
@@ -320,7 +320,7 @@ double GPUMultiClassAUCOVR(MetaInfo const &info, DeviceOrd device,
// Index is sorted within class.
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
auto labels = info.labels.View(device);
auto labels = info.labels.View(ctx->Device());
auto weights = info.weights_.ConstDeviceSpan();
size_t n_samples = labels.Shape(0);
@@ -328,12 +328,11 @@ double GPUMultiClassAUCOVR(MetaInfo const &info, DeviceOrd device,
if (n_samples == 0) {
dh::TemporaryArray<double> resutls(n_classes * 4, 0.0f);
auto d_results = dh::ToSpan(resutls);
dh::LaunchN(n_classes * 4,
[=] XGBOOST_DEVICE(size_t i) { d_results[i] = 0.0f; });
dh::LaunchN(n_classes * 4, [=] XGBOOST_DEVICE(size_t i) { d_results[i] = 0.0f; });
auto local_area = d_results.subspan(0, n_classes);
auto tp = d_results.subspan(2 * n_classes, n_classes);
auto auc = d_results.subspan(3 * n_classes, n_classes);
return ScaleClasses(d_results, local_area, tp, auc, n_classes);
return ScaleClasses(ctx, d_results, local_area, tp, auc, n_classes);
}
/**
@@ -437,7 +436,7 @@ double GPUMultiClassAUCOVR(MetaInfo const &info, DeviceOrd device,
tp[c] = 1.0f;
}
});
return ScaleClasses(d_results, local_area, tp, auc, n_classes);
return ScaleClasses(ctx, d_results, local_area, tp, auc, n_classes);
}
void MultiClassSortedIdx(Context const *ctx, common::Span<float const> predts,
@@ -472,8 +471,7 @@ double GPUMultiClassROCAUC(Context const *ctx, common::Span<float const> predts,
size_t /*class_id*/) {
return TrapezoidArea(fp_prev, fp, tp_prev, tp);
};
return GPUMultiClassAUCOVR<true>(info, ctx->Device(), dh::ToSpan(class_ptr), n_classes, cache,
fn);
return GPUMultiClassAUCOVR<true>(ctx, info, dh::ToSpan(class_ptr), n_classes, cache, fn);
}
namespace {
@@ -697,7 +695,7 @@ double GPUMultiClassPRAUC(Context const *ctx, common::Span<float const> predts,
return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp,
d_totals[class_id].first);
};
return GPUMultiClassAUCOVR<false>(info, ctx->Device(), d_class_ptr, n_classes, cache, fn);
return GPUMultiClassAUCOVR<false>(ctx, info, d_class_ptr, n_classes, cache, fn);
}
template <typename Fn>

View File

@@ -215,7 +215,7 @@ struct EvalError {
has_param_ = false;
}
}
const char *Name() const {
[[nodiscard]] const char *Name() const {
static thread_local std::string name;
if (has_param_) {
std::ostringstream os;
@@ -228,7 +228,7 @@ struct EvalError {
}
}
XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float pred) const {
[[nodiscard]] XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float pred) const {
// assume label is in [0,1]
return pred > threshold_ ? 1.0f - label : label;
}
@@ -370,7 +370,7 @@ struct EvalEWiseBase : public MetricNoCache {
return Policy::GetFinal(dat[0], dat[1]);
}
const char* Name() const override { return policy_.Name(); }
[[nodiscard]] const char* Name() const override { return policy_.Name(); }
private:
Policy policy_;

View File

@@ -162,7 +162,7 @@ struct EvalRank : public MetricNoCache, public EvalRankConfig {
return collective::GlobalRatio(info, sum_metric, static_cast<double>(ngroups));
}
const char* Name() const override {
[[nodiscard]] const char* Name() const override {
return name.c_str();
}
@@ -294,7 +294,7 @@ class EvalRankWithCache : public Metric {
};
namespace {
double Finalize(MetaInfo const& info, double score, double sw) {
double Finalize(Context const*, MetaInfo const& info, double score, double sw) {
std::array<double, 2> dat{score, sw};
collective::GlobalSum(info, &dat);
std::tie(score, sw) = std::tuple_cat(dat);
@@ -323,7 +323,7 @@ class EvalPrecision : public EvalRankWithCache<ltr::PreCache> {
if (ctx_->IsCUDA()) {
auto pre = cuda_impl::PreScore(ctx_, info, predt, p_cache);
return Finalize(info, pre.Residue(), pre.Weights());
return Finalize(ctx_, info, pre.Residue(), pre.Weights());
}
auto gptr = p_cache->DataGroupPtr(ctx_);
@@ -352,7 +352,7 @@ class EvalPrecision : public EvalRankWithCache<ltr::PreCache> {
}
auto sum = std::accumulate(pre.cbegin(), pre.cend(), 0.0);
return Finalize(info, sum, sw);
return Finalize(ctx_, info, sum, sw);
}
};
@@ -369,7 +369,7 @@ class EvalNDCG : public EvalRankWithCache<ltr::NDCGCache> {
std::shared_ptr<ltr::NDCGCache> p_cache) override {
if (ctx_->IsCUDA()) {
auto ndcg = cuda_impl::NDCGScore(ctx_, info, preds, minus_, p_cache);
return Finalize(info, ndcg.Residue(), ndcg.Weights());
return Finalize(ctx_, info, ndcg.Residue(), ndcg.Weights());
}
// group local ndcg
@@ -415,7 +415,7 @@ class EvalNDCG : public EvalRankWithCache<ltr::NDCGCache> {
sum_w = std::accumulate(weights.weights.cbegin(), weights.weights.cend(), 0.0);
}
auto ndcg = std::accumulate(linalg::cbegin(ndcg_gloc), linalg::cend(ndcg_gloc), 0.0);
return Finalize(info, ndcg, sum_w);
return Finalize(ctx_, info, ndcg, sum_w);
}
};
@@ -427,7 +427,7 @@ class EvalMAPScore : public EvalRankWithCache<ltr::MAPCache> {
std::shared_ptr<ltr::MAPCache> p_cache) override {
if (ctx_->IsCUDA()) {
auto map = cuda_impl::MAPScore(ctx_, info, predt, minus_, p_cache);
return Finalize(info, map.Residue(), map.Weights());
return Finalize(ctx_, info, map.Residue(), map.Weights());
}
auto gptr = p_cache->DataGroupPtr(ctx_);
@@ -469,7 +469,7 @@ class EvalMAPScore : public EvalRankWithCache<ltr::MAPCache> {
sw += weight[i];
}
auto sum = std::accumulate(map_gloc.cbegin(), map_gloc.cend(), 0.0);
return Finalize(info, sum, sw);
return Finalize(ctx_, info, sum, sw);
}
};

View File

@@ -217,7 +217,7 @@ struct EvalEWiseSurvivalBase : public MetricNoCache {
return Policy::GetFinal(dat[0], dat[1]);
}
const char* Name() const override {
[[nodiscard]] const char* Name() const override {
return policy_.Name();
}