[coll] Pass context to various functions. (#9772)
* [coll] Pass context to various functions. In the future, the `Context` object would be required for collective operations, this PR passes the context object to some required functions to prepare for swapping out the implementation.
This commit is contained in:
@@ -199,9 +199,9 @@ void Transpose(common::Span<float const> in, common::Span<float> out, size_t m,
|
||||
});
|
||||
}
|
||||
|
||||
double ScaleClasses(common::Span<double> results, common::Span<double> local_area,
|
||||
common::Span<double> tp, common::Span<double> auc, size_t n_classes) {
|
||||
dh::XGBDeviceAllocator<char> alloc;
|
||||
double ScaleClasses(Context const *ctx, common::Span<double> results,
|
||||
common::Span<double> local_area, common::Span<double> tp,
|
||||
common::Span<double> auc, size_t n_classes) {
|
||||
if (collective::IsDistributed()) {
|
||||
int32_t device = dh::CurrentDevice();
|
||||
CHECK_EQ(dh::CudaGetPointerDevice(results.data()), device);
|
||||
@@ -218,8 +218,8 @@ double ScaleClasses(common::Span<double> results, common::Span<double> local_are
|
||||
double tp_sum;
|
||||
double auc_sum;
|
||||
thrust::tie(auc_sum, tp_sum) =
|
||||
thrust::reduce(thrust::cuda::par(alloc), reduce_in, reduce_in + n_classes,
|
||||
Pair{0.0, 0.0}, PairPlus<double, double>{});
|
||||
thrust::reduce(ctx->CUDACtx()->CTP(), reduce_in, reduce_in + n_classes, Pair{0.0, 0.0},
|
||||
PairPlus<double, double>{});
|
||||
if (tp_sum != 0 && !std::isnan(auc_sum)) {
|
||||
auc_sum /= tp_sum;
|
||||
} else {
|
||||
@@ -309,10 +309,10 @@ void SegmentedReduceAUC(common::Span<size_t const> d_unique_idx,
|
||||
* up each class in all kernels.
|
||||
*/
|
||||
template <bool scale, typename Fn>
|
||||
double GPUMultiClassAUCOVR(MetaInfo const &info, DeviceOrd device,
|
||||
double GPUMultiClassAUCOVR(Context const *ctx, MetaInfo const &info,
|
||||
common::Span<uint32_t> d_class_ptr, size_t n_classes,
|
||||
std::shared_ptr<DeviceAUCCache> cache, Fn area_fn) {
|
||||
dh::safe_cuda(cudaSetDevice(device.ordinal));
|
||||
dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));
|
||||
/**
|
||||
* Sorted idx
|
||||
*/
|
||||
@@ -320,7 +320,7 @@ double GPUMultiClassAUCOVR(MetaInfo const &info, DeviceOrd device,
|
||||
// Index is sorted within class.
|
||||
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
|
||||
|
||||
auto labels = info.labels.View(device);
|
||||
auto labels = info.labels.View(ctx->Device());
|
||||
auto weights = info.weights_.ConstDeviceSpan();
|
||||
|
||||
size_t n_samples = labels.Shape(0);
|
||||
@@ -328,12 +328,11 @@ double GPUMultiClassAUCOVR(MetaInfo const &info, DeviceOrd device,
|
||||
if (n_samples == 0) {
|
||||
dh::TemporaryArray<double> resutls(n_classes * 4, 0.0f);
|
||||
auto d_results = dh::ToSpan(resutls);
|
||||
dh::LaunchN(n_classes * 4,
|
||||
[=] XGBOOST_DEVICE(size_t i) { d_results[i] = 0.0f; });
|
||||
dh::LaunchN(n_classes * 4, [=] XGBOOST_DEVICE(size_t i) { d_results[i] = 0.0f; });
|
||||
auto local_area = d_results.subspan(0, n_classes);
|
||||
auto tp = d_results.subspan(2 * n_classes, n_classes);
|
||||
auto auc = d_results.subspan(3 * n_classes, n_classes);
|
||||
return ScaleClasses(d_results, local_area, tp, auc, n_classes);
|
||||
return ScaleClasses(ctx, d_results, local_area, tp, auc, n_classes);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -437,7 +436,7 @@ double GPUMultiClassAUCOVR(MetaInfo const &info, DeviceOrd device,
|
||||
tp[c] = 1.0f;
|
||||
}
|
||||
});
|
||||
return ScaleClasses(d_results, local_area, tp, auc, n_classes);
|
||||
return ScaleClasses(ctx, d_results, local_area, tp, auc, n_classes);
|
||||
}
|
||||
|
||||
void MultiClassSortedIdx(Context const *ctx, common::Span<float const> predts,
|
||||
@@ -472,8 +471,7 @@ double GPUMultiClassROCAUC(Context const *ctx, common::Span<float const> predts,
|
||||
size_t /*class_id*/) {
|
||||
return TrapezoidArea(fp_prev, fp, tp_prev, tp);
|
||||
};
|
||||
return GPUMultiClassAUCOVR<true>(info, ctx->Device(), dh::ToSpan(class_ptr), n_classes, cache,
|
||||
fn);
|
||||
return GPUMultiClassAUCOVR<true>(ctx, info, dh::ToSpan(class_ptr), n_classes, cache, fn);
|
||||
}
|
||||
|
||||
namespace {
|
||||
@@ -697,7 +695,7 @@ double GPUMultiClassPRAUC(Context const *ctx, common::Span<float const> predts,
|
||||
return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp,
|
||||
d_totals[class_id].first);
|
||||
};
|
||||
return GPUMultiClassAUCOVR<false>(info, ctx->Device(), d_class_ptr, n_classes, cache, fn);
|
||||
return GPUMultiClassAUCOVR<false>(ctx, info, d_class_ptr, n_classes, cache, fn);
|
||||
}
|
||||
|
||||
template <typename Fn>
|
||||
|
||||
@@ -215,7 +215,7 @@ struct EvalError {
|
||||
has_param_ = false;
|
||||
}
|
||||
}
|
||||
const char *Name() const {
|
||||
[[nodiscard]] const char *Name() const {
|
||||
static thread_local std::string name;
|
||||
if (has_param_) {
|
||||
std::ostringstream os;
|
||||
@@ -228,7 +228,7 @@ struct EvalError {
|
||||
}
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float pred) const {
|
||||
[[nodiscard]] XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float pred) const {
|
||||
// assume label is in [0,1]
|
||||
return pred > threshold_ ? 1.0f - label : label;
|
||||
}
|
||||
@@ -370,7 +370,7 @@ struct EvalEWiseBase : public MetricNoCache {
|
||||
return Policy::GetFinal(dat[0], dat[1]);
|
||||
}
|
||||
|
||||
const char* Name() const override { return policy_.Name(); }
|
||||
[[nodiscard]] const char* Name() const override { return policy_.Name(); }
|
||||
|
||||
private:
|
||||
Policy policy_;
|
||||
|
||||
@@ -162,7 +162,7 @@ struct EvalRank : public MetricNoCache, public EvalRankConfig {
|
||||
return collective::GlobalRatio(info, sum_metric, static_cast<double>(ngroups));
|
||||
}
|
||||
|
||||
const char* Name() const override {
|
||||
[[nodiscard]] const char* Name() const override {
|
||||
return name.c_str();
|
||||
}
|
||||
|
||||
@@ -294,7 +294,7 @@ class EvalRankWithCache : public Metric {
|
||||
};
|
||||
|
||||
namespace {
|
||||
double Finalize(MetaInfo const& info, double score, double sw) {
|
||||
double Finalize(Context const*, MetaInfo const& info, double score, double sw) {
|
||||
std::array<double, 2> dat{score, sw};
|
||||
collective::GlobalSum(info, &dat);
|
||||
std::tie(score, sw) = std::tuple_cat(dat);
|
||||
@@ -323,7 +323,7 @@ class EvalPrecision : public EvalRankWithCache<ltr::PreCache> {
|
||||
|
||||
if (ctx_->IsCUDA()) {
|
||||
auto pre = cuda_impl::PreScore(ctx_, info, predt, p_cache);
|
||||
return Finalize(info, pre.Residue(), pre.Weights());
|
||||
return Finalize(ctx_, info, pre.Residue(), pre.Weights());
|
||||
}
|
||||
|
||||
auto gptr = p_cache->DataGroupPtr(ctx_);
|
||||
@@ -352,7 +352,7 @@ class EvalPrecision : public EvalRankWithCache<ltr::PreCache> {
|
||||
}
|
||||
|
||||
auto sum = std::accumulate(pre.cbegin(), pre.cend(), 0.0);
|
||||
return Finalize(info, sum, sw);
|
||||
return Finalize(ctx_, info, sum, sw);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -369,7 +369,7 @@ class EvalNDCG : public EvalRankWithCache<ltr::NDCGCache> {
|
||||
std::shared_ptr<ltr::NDCGCache> p_cache) override {
|
||||
if (ctx_->IsCUDA()) {
|
||||
auto ndcg = cuda_impl::NDCGScore(ctx_, info, preds, minus_, p_cache);
|
||||
return Finalize(info, ndcg.Residue(), ndcg.Weights());
|
||||
return Finalize(ctx_, info, ndcg.Residue(), ndcg.Weights());
|
||||
}
|
||||
|
||||
// group local ndcg
|
||||
@@ -415,7 +415,7 @@ class EvalNDCG : public EvalRankWithCache<ltr::NDCGCache> {
|
||||
sum_w = std::accumulate(weights.weights.cbegin(), weights.weights.cend(), 0.0);
|
||||
}
|
||||
auto ndcg = std::accumulate(linalg::cbegin(ndcg_gloc), linalg::cend(ndcg_gloc), 0.0);
|
||||
return Finalize(info, ndcg, sum_w);
|
||||
return Finalize(ctx_, info, ndcg, sum_w);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -427,7 +427,7 @@ class EvalMAPScore : public EvalRankWithCache<ltr::MAPCache> {
|
||||
std::shared_ptr<ltr::MAPCache> p_cache) override {
|
||||
if (ctx_->IsCUDA()) {
|
||||
auto map = cuda_impl::MAPScore(ctx_, info, predt, minus_, p_cache);
|
||||
return Finalize(info, map.Residue(), map.Weights());
|
||||
return Finalize(ctx_, info, map.Residue(), map.Weights());
|
||||
}
|
||||
|
||||
auto gptr = p_cache->DataGroupPtr(ctx_);
|
||||
@@ -469,7 +469,7 @@ class EvalMAPScore : public EvalRankWithCache<ltr::MAPCache> {
|
||||
sw += weight[i];
|
||||
}
|
||||
auto sum = std::accumulate(map_gloc.cbegin(), map_gloc.cend(), 0.0);
|
||||
return Finalize(info, sum, sw);
|
||||
return Finalize(ctx_, info, sum, sw);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -217,7 +217,7 @@ struct EvalEWiseSurvivalBase : public MetricNoCache {
|
||||
return Policy::GetFinal(dat[0], dat[1]);
|
||||
}
|
||||
|
||||
const char* Name() const override {
|
||||
[[nodiscard]] const char* Name() const override {
|
||||
return policy_.Name();
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user