From 5ac233280e1218fcf9de011fbbbe7841402d9866 Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Wed, 28 Feb 2024 03:12:42 +0800
Subject: [PATCH] Require context in aggregators. (#10075)

---
 .clang-format                       |  2 +-
 include/xgboost/collective/result.h | 12 ++++++--
 src/collective/aggregator.h         | 46 ++++++++++++++++-------------
 src/common/quantile.cc              | 34 +++++++++++++--------
 src/learner.cc                      |  6 ++--
 src/metric/auc.cc                   | 10 ++++---
 src/metric/elementwise_metric.cu    | 37 ++++++++++++-----------
 src/metric/metric_common.h          |  7 ++---
 src/metric/multiclass_metric.cu     | 20 ++++++-------
 src/metric/rank_metric.cc           | 11 +++----
 src/metric/rank_metric.cu           |  2 +-
 src/metric/survival_metric.cu       | 28 ++++++++----------
 src/objective/adaptive.cc           |  8 ++---
 src/objective/adaptive.cu           |  9 +++---
 src/objective/adaptive.h            | 23 ++++++++-------
 src/objective/quantile_obj.cu       |  6 ++--
 src/objective/regression_obj.cu     | 10 +++++--
 src/tree/fit_stump.cc               | 11 ++++---
 src/tree/fit_stump.cu               | 17 +++++------
 src/tree/gpu_hist/histogram.cu      | 11 ++++---
 src/tree/updater_approx.cc          |  7 +++--
 src/tree/updater_gpu_hist.cu        |  6 ++--
 src/tree/updater_quantile_hist.cc   | 11 +++++--
 23 files changed, 190 insertions(+), 144 deletions(-)

diff --git a/.clang-format b/.clang-format
index 0984d5a7b..737cf9006 100644
--- a/.clang-format
+++ b/.clang-format
@@ -17,7 +17,7 @@ AllowShortEnumsOnASingleLine: true
 AllowShortBlocksOnASingleLine: Never
 AllowShortCaseLabelsOnASingleLine: false
 AllowShortFunctionsOnASingleLine: All
-AllowShortLambdasOnASingleLine: All
+AllowShortLambdasOnASingleLine: Inline
 AllowShortIfStatementsOnASingleLine: WithoutElse
 AllowShortLoopsOnASingleLine: true
 AlwaysBreakAfterDefinitionReturnType: None

diff --git a/include/xgboost/collective/result.h b/include/xgboost/collective/result.h
index 507171dd4..919d3a902 100644
--- a/include/xgboost/collective/result.h
+++ b/include/xgboost/collective/result.h
@@ -1,8 +1,10 @@
 /**
- * Copyright 2023, XGBoost Contributors
+ * Copyright 2023-2024, XGBoost Contributors
  */
 #pragma once
 
+#include <xgboost/logging.h>
+
 #include <memory>   // for unique_ptr
 #include <sstream>  // for stringstream
 #include <stack>    // for stack
@@ -160,10 +162,16 @@ struct Result {
 
 // We don't have monad, a simple helper would do.
 template <typename Fn>
-Result operator<<(Result&& r, Fn&& fn) {
+[[nodiscard]] Result operator<<(Result&& r, Fn&& fn) {
   if (!r.OK()) {
     return std::forward<Result>(r);
   }
   return fn();
 }
+
+inline void SafeColl(Result const& rc) {
+  if (!rc.OK()) {
+    LOG(FATAL) << rc.Report();
+  }
+}
 }  // namespace xgboost::collective

diff --git a/src/collective/aggregator.h b/src/collective/aggregator.h
index f2a9ff528..8a5b31c36 100644
--- a/src/collective/aggregator.h
+++ b/src/collective/aggregator.h
@@ -1,22 +1,21 @@
 /**
- * Copyright 2023 by XGBoost contributors
+ * Copyright 2023-2024, XGBoost contributors
  *
  * Higher level functions built on top of the Communicator API, taking care of behavioral differences
  * between row-split vs column-split distributed training, and horizontal vs vertical federated
  * learning.
  */
 #pragma once
-#include
-
 #include
 #include
 #include
 #include
 
 #include "communicator-inl.h"
+#include "xgboost/collective/result.h"  // for Result
+#include "xgboost/data.h"               // for MetaInfo
 
-namespace xgboost {
-namespace collective {
+namespace xgboost::collective {
 
 /**
  * @brief Apply the given function where the labels are.
  *
 * @param info MetaInfo about the DMatrix.
 * @param buffer The buffer storing the results.
 * @param size The size of the buffer.
 * @param function The function used to calculate the results.
 */
-template <typename Function>
-void ApplyWithLabels(MetaInfo const& info, void* buffer, size_t size, Function&& function) {
+template <typename FN>
+void ApplyWithLabels(Context const*, MetaInfo const& info, void* buffer, std::size_t size,
+                     FN&& function) {
   if (info.IsVerticalFederated()) {
     // We assume labels are only available on worker 0, so the calculation is done there and result
     // broadcast to other workers.
     std::string message;
     if (collective::GetRank() == 0) {
       try {
-        std::forward<Function>(function)();
+        std::forward<FN>(function)();
       } catch (dmlc::Error& e) {
         message = e.what();
       }
     }
@@ -52,7 +52,7 @@ void ApplyWithLabels(MetaInfo const& info, void* buffer, size_t size, Function&&
       LOG(FATAL) << &message[0];
     }
   } else {
-    std::forward<Function>(function)();
+    std::forward<FN>(function)();
   }
 }
@@ -70,7 +70,8 @@ void ApplyWithLabels(MetaInfo const& info, void* buffer, size_t size, Function&&
  * @param function The function used to calculate the results.
  */
 template <typename T, typename Function>
-void ApplyWithLabels(MetaInfo const& info, HostDeviceVector<T>* result, Function&& function) {
+void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector<T>* result,
+                     Function&& function) {
   if (info.IsVerticalFederated()) {
     // We assume labels are only available on worker 0, so the calculation is done there and result
     // broadcast to other workers.
@@ -114,7 +115,9 @@ void ApplyWithLabels(MetaInfo const& info, HostDeviceVector<T>* result, Function
  * @return The global max of the input.
  */
 template <typename T>
-T GlobalMax(MetaInfo const& info, T value) {
+std::enable_if_t<std::is_trivial_v<T>, T> GlobalMax(Context const*,
+                                                    MetaInfo const& info,
+                                                    T value) {
   if (info.IsRowSplit()) {
     collective::Allreduce<collective::Operation::kMax>(&value, 1);
   }
@@ -132,16 +135,18 @@ T GlobalMax(MetaInfo const& info, T value) {
  * @param values Pointer to the inputs to sum.
  * @param size Number of values to sum.
  */
-template <typename T>
-void GlobalSum(MetaInfo const& info, T* values, size_t size) {
+template <typename T, std::int32_t kDim>
+[[nodiscard]] Result GlobalSum(Context const*, MetaInfo const& info,
+                               linalg::TensorView<T, kDim> values) {
   if (info.IsRowSplit()) {
-    collective::Allreduce<collective::Operation::kSum>(values, size);
+    collective::Allreduce<collective::Operation::kSum>(values.Values().data(), values.Size());
   }
+  return Success();
 }
 
 template <typename Container>
-void GlobalSum(MetaInfo const& info, Container* values) {
-  GlobalSum(info, values->data(), values->size());
+[[nodiscard]] Result GlobalSum(Context const* ctx, MetaInfo const& info, Container* values) {
+  return GlobalSum(ctx, info, values->data(), values->size());
 }
 
 /**
@@ -157,9 +162,10 @@ void GlobalSum(MetaInfo const& info, Container* values) {
  * @return The global ratio of the two inputs.
  */
 template <typename T>
-T GlobalRatio(MetaInfo const& info, T dividend, T divisor) {
+T GlobalRatio(Context const* ctx, MetaInfo const& info, T dividend, T divisor) {
   std::array<T, 2> results{dividend, divisor};
-  GlobalSum(info, &results);
+  auto rc = GlobalSum(ctx, info, linalg::MakeVec(results.data(), results.size()));
+  collective::SafeColl(rc);
   std::tie(dividend, divisor) = std::tuple_cat(results);
   if (divisor <= 0) {
     return std::numeric_limits<T>::quiet_NaN();
@@ -167,6 +173,4 @@ T GlobalRatio(MetaInfo const& info, T dividend, T divisor) {
     return dividend / divisor;
   }
 }
-
-}  // namespace collective
-}  // namespace xgboost
+}  // namespace xgboost::collective

diff --git a/src/common/quantile.cc b/src/common/quantile.cc
index c74db99e4..e521fae69 100644
--- a/src/common/quantile.cc
+++ b/src/common/quantile.cc
@@ -1,5 +1,5 @@
-/*!
- * Copyright 2020-2022 by XGBoost Contributors
+/**
+ * Copyright 2020-2024, XGBoost Contributors
  */
 #include "quantile.h"
@@ -145,7 +145,7 @@ struct QuantileAllreduce {
 template <typename WQSketch>
 void SketchContainerImpl<WQSketch>::GatherSketchInfo(
-    Context const *, MetaInfo const &info,
+    Context const *ctx, MetaInfo const &info,
     std::vector<typename WQSketch::SummaryContainer> const &reduced,
     std::vector<size_t> *p_worker_segments, std::vector<bst_row_t> *p_sketches_scan,
     std::vector<typename WQSketch::Entry> *p_global_sketches) {
@@ -171,7 +171,9 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
   std::partial_sum(sketch_size.cbegin(), sketch_size.cend(), sketches_scan.begin() + beg_scan + 1);
 
   // Gather all column pointers
-  collective::GlobalSum(info, sketches_scan.data(), sketches_scan.size());
+  auto rc =
+      collective::GlobalSum(ctx, info, linalg::MakeVec(sketches_scan.data(), sketches_scan.size()));
+  collective::SafeColl(rc);
   for (int32_t i = 0; i < world; ++i) {
     size_t back = (i + 1) * (n_columns + 1) - 1;
     auto n_entries = sketches_scan.at(back);
@@ -199,14 +201,15 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
   static_assert(sizeof(typename WQSketch::Entry) / 4 == sizeof(float),
                 "Unexpected size of sketch entry.");
-  collective::GlobalSum(
-      info,
-      reinterpret_cast<float *>(global_sketches.data()),
-      global_sketches.size() * sizeof(typename WQSketch::Entry) / sizeof(float));
+  rc = collective::GlobalSum(
+      ctx, info,
+      linalg::MakeVec(reinterpret_cast<float *>(global_sketches.data()),
+                      global_sketches.size() * sizeof(typename WQSketch::Entry) / sizeof(float)));
+  collective::SafeColl(rc);
 }
 
 template <typename WQSketch>
-void SketchContainerImpl<WQSketch>::AllreduceCategories(Context const*, MetaInfo const& info) {
+void SketchContainerImpl<WQSketch>::AllreduceCategories(Context const* ctx, MetaInfo const& info) {
   auto world_size = collective::GetWorldSize();
   auto rank = collective::GetRank();
   if (world_size == 1 || info.IsColumnSplit()) {
@@ -226,7 +229,8 @@ void SketchContainerImpl<WQSketch>::AllreduceCategories(Context const*, MetaInfo
   std::vector<size_t> global_feat_ptrs(feature_ptr.size() * world_size, 0);
   size_t feat_begin = rank * feature_ptr.size();  // pointer to current worker
   std::copy(feature_ptr.begin(), feature_ptr.end(), global_feat_ptrs.begin() + feat_begin);
-  collective::GlobalSum(info, global_feat_ptrs.data(), global_feat_ptrs.size());
+  auto rc = collective::GlobalSum(
+      ctx, info, linalg::MakeVec(global_feat_ptrs.data(), global_feat_ptrs.size()));
 
   // move all categories into a flatten vector to prepare for allreduce
   size_t total = feature_ptr.back();
@@ -239,7 +243,8 @@ void SketchContainerImpl<WQSketch>::AllreduceCategories(Context const*, MetaInfo
   // indptr for indexing workers
   std::vector<size_t> global_worker_ptr(world_size + 1, 0);
   global_worker_ptr[rank + 1] = total;  // shift 1 to right for constructing the indptr
-  collective::GlobalSum(info, global_worker_ptr.data(), global_worker_ptr.size());
+  rc = collective::GlobalSum(ctx, info,
+                             linalg::MakeVec(global_worker_ptr.data(), global_worker_ptr.size()));
   std::partial_sum(global_worker_ptr.cbegin(), global_worker_ptr.cend(), global_worker_ptr.begin());
   // total number of categories in all workers with all features
   auto gtotal = global_worker_ptr.back();
@@ -251,7 +256,8 @@ void SketchContainerImpl<WQSketch>::AllreduceCategories(Context const*, MetaInfo
   CHECK_EQ(rank_size, total);
   std::copy(flatten.cbegin(), flatten.cend(), global_categories.begin() + rank_begin);
   // gather values from all workers.
-  collective::GlobalSum(info, global_categories.data(), global_categories.size());
+  rc = collective::GlobalSum(ctx, info,
+                             linalg::MakeVec(global_categories.data(), global_categories.size()));
   QuantileAllreduce<float> allreduce_result{global_categories, global_worker_ptr, global_feat_ptrs,
                                             categories_.size()};
   ParallelFor(categories_.size(), n_threads_, [&](auto fidx) {
@@ -293,7 +299,9 @@ void SketchContainerImpl<WQSketch>::AllReduce(
 
   // Prune the intermediate num cuts for synchronization.
   std::vector<bst_row_t> global_column_size(columns_size_);
-  collective::GlobalSum(info, &global_column_size);
+  auto rc = collective::GlobalSum(
+      ctx, info, linalg::MakeVec(global_column_size.data(), global_column_size.size()));
+  collective::SafeColl(rc);
 
   ParallelFor(sketches_.size(), n_threads_, [&](size_t i) {
     int32_t intermediate_num_cuts = static_cast<int32_t>(

diff --git a/src/learner.cc b/src/learner.cc
index db72f7164..eed9dd5cd 100644
--- a/src/learner.cc
+++ b/src/learner.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2014-2023 by XGBoost Contributors
+ * Copyright 2014-2024, XGBoost Contributors
  * \file learner.cc
  * \brief Implementation of learning algorithm.
  * \author Tianqi Chen
@@ -846,7 +846,7 @@ class LearnerConfiguration : public Learner {
 
   void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
     base_score->Reshape(1);
-    collective::ApplyWithLabels(info, base_score->Data(),
+    collective::ApplyWithLabels(this->Ctx(), info, base_score->Data(),
                                 [&] { UsePtr(obj_)->InitEstimation(info, base_score); });
   }
 };
@@ -1472,7 +1472,7 @@ class LearnerImpl : public LearnerIO {
   void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info,
                    std::int32_t iter, linalg::Matrix<GradientPair>* out_gpair) {
     out_gpair->Reshape(info.num_row_, this->learner_model_param_.OutputLength());
-    collective::ApplyWithLabels(info, out_gpair->Data(),
+    collective::ApplyWithLabels(&ctx_, info, out_gpair->Data(),
                                 [&] { obj_->GetGradient(preds, info, iter, out_gpair); });
   }

diff --git a/src/metric/auc.cc b/src/metric/auc.cc
index 4a8aa8a4b..212a3a027 100644
--- a/src/metric/auc.cc
+++ b/src/metric/auc.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2021-2023 by XGBoost Contributors
+ * Copyright 2021-2024, XGBoost Contributors
  */
 #include "auc.h"
@@ -112,7 +112,9 @@ double MultiClassOVR(Context const *ctx, common::Span<float const> predts, MetaI
 
   // we have 2 averages going in here, first is among workers, second is among
   // classes. allreduce sums up fp/tp auc for each class.
-  collective::GlobalSum(info, &results.Values());
+  auto rc = collective::GlobalSum(ctx, info, results);
+  collective::SafeColl(rc);
+
   double auc_sum{0};
   double tp_sum{0};
   for (size_t c = 0; c < n_classes; ++c) {
@@ -286,7 +288,7 @@ class EvalAUC : public MetricNoCache {
       InvalidGroupAUC();
     }
 
-    auc = collective::GlobalRatio(info, auc, static_cast<double>(valid_groups));
+    auc = collective::GlobalRatio(ctx_, info, auc, static_cast<double>(valid_groups));
     if (!std::isnan(auc)) {
       CHECK_LE(auc, 1) << "Total AUC across groups: " << auc * valid_groups
                        << ", valid groups: " << valid_groups;
@@ -307,7 +309,7 @@ class EvalAUC : public MetricNoCache {
       std::tie(fp, tp, auc) = static_cast<Curve *>(this)->EvalBinary(preds, info);
     }
 
-    auc = collective::GlobalRatio(info, auc, fp * tp);
+    auc = collective::GlobalRatio(ctx_, info, auc, fp * tp);
     if (!std::isnan(auc)) {
       CHECK_LE(auc, 1.0);
     }

diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu
index f245f3e06..9c26011aa 100644
--- a/src/metric/elementwise_metric.cu
+++ b/src/metric/elementwise_metric.cu
@@ -1,5 +1,5 @@
 /**
- * Copyright 2015-2023 by XGBoost Contributors
+ * Copyright 2015-2024, XGBoost Contributors
 * \file elementwise_metric.cu
 * \brief evaluation metrics for elementwise binary or regression.
 * \author Kailong Chen, Tianqi Chen
@@ -12,13 +12,14 @@
 #include
 
 #include "../collective/communicator-inl.h"
-#include "../common/common.h" // MetricNoCache
+#include "../common/common.h"  // MetricNoCache
 #include "../common/math.h"
 #include "../common/optional_weight.h"  // OptionalWeights
 #include "../common/pseudo_huber.h"
 #include "../common/quantile_loss_utils.h"  // QuantileLossParam
 #include "../common/threading_utils.h"
 #include "metric_common.h"
+#include "xgboost/collective/result.h"  // for SafeColl
 #include "xgboost/metric.h"
 
 #if defined(XGBOOST_USE_CUDA)
@@ -30,8 +31,7 @@
 #include "../common/device_helpers.cuh"
 #endif  // XGBOOST_USE_CUDA
 
-namespace xgboost {
-namespace metric {
+namespace xgboost::metric {
 // tag this file, used by force static link later.
 DMLC_REGISTRY_FILE_TAG(elementwise_metric);
@@ -199,7 +199,8 @@ class PseudoErrorLoss : public MetricNoCache {
       return std::make_tuple(v, wt);
     });
     std::array<double, 2> dat{result.Residue(), result.Weights()};
-    collective::GlobalSum(info, &dat);
+    auto rc = collective::GlobalSum(ctx_, info, linalg::MakeVec(dat.data(), dat.size()));
+    collective::SafeColl(rc);
     return EvalRowMAPE::GetFinal(dat[0], dat[1]);
   }
 };
@@ -243,11 +244,11 @@ struct EvalError {
 };
 
 struct EvalPoissonNegLogLik {
-  const char *Name() const {
+  [[nodiscard]] const char *Name() const {
     return "poisson-nloglik";
   }
 
-  XGBOOST_DEVICE bst_float EvalRow(bst_float y, bst_float py) const {
+  [[nodiscard]] XGBOOST_DEVICE bst_float EvalRow(bst_float y, bst_float py) const {
     const bst_float eps = 1e-16f;
     if (py < eps) py = eps;
     return common::LogGamma(y + 1.0f) + py - std::log(py) * y;
@@ -266,9 +267,9 @@
  *   predt >= 0
  */
 struct EvalGammaDeviance {
-  const char *Name() const { return "gamma-deviance"; }
+  [[nodiscard]] const char *Name() const { return "gamma-deviance"; }
 
-  XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float predt) const {
+  [[nodiscard]] XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float predt) const {
     predt += kRtEps;
     label += kRtEps;
     return std::log(predt / label) + label / predt - 1;
@@ -287,7 +288,7 @@ struct EvalGammaNLogLik {
     return "gamma-nloglik";
   }
 
-  XGBOOST_DEVICE bst_float EvalRow(bst_float y, bst_float py) const {
+  [[nodiscard]] XGBOOST_DEVICE bst_float EvalRow(bst_float y, bst_float py) const {
     py = std::max(py, 1e-6f);
     // hardcoded dispersion.
     float constexpr kPsi = 1.0;
@@ -313,7 +314,7 @@ struct EvalTweedieNLogLik {
     CHECK(rho_ < 2 && rho_ >= 1) << "tweedie variance power must be in interval [1, 2)";
   }
 
-  const char *Name() const {
+  [[nodiscard]] const char *Name() const {
     static thread_local std::string name;
     std::ostringstream os;
     os << "tweedie-nloglik@" << rho_;
@@ -321,7 +322,7 @@ struct EvalTweedieNLogLik {
     return name.c_str();
   }
 
-  XGBOOST_DEVICE bst_float EvalRow(bst_float y, bst_float p) const {
+  [[nodiscard]] XGBOOST_DEVICE bst_float EvalRow(bst_float y, bst_float p) const {
     bst_float a = y * std::exp((1 - rho_) * std::log(p)) / (1 - rho_);
     bst_float b = std::exp((2 - rho_) * std::log(p)) / (2 - rho_);
     return -a + b;
@@ -366,7 +367,8 @@ struct EvalEWiseBase : public MetricNoCache {
     });
 
     std::array<double, 2> dat{result.Residue(), result.Weights()};
-    collective::GlobalSum(info, &dat);
+    auto rc = collective::GlobalSum(ctx_, info, linalg::MakeVec(dat.data(), dat.size()));
+    collective::SafeColl(rc);
     return Policy::GetFinal(dat[0], dat[1]);
   }
@@ -438,7 +440,8 @@ class QuantileError : public MetricNoCache {
     if (info.num_row_ == 0) {
       // empty DMatrix on distributed env
       std::array<double, 2> dat{0.0, 0.0};
-      collective::GlobalSum(info, &dat);
+      auto rc = collective::GlobalSum(ctx_, info, linalg::MakeVec(dat.data(), dat.size()));
+      collective::SafeColl(rc);
       CHECK_GT(dat[1], 0);
       return dat[0] / dat[1];
     }
@@ -476,7 +479,8 @@ class QuantileError : public MetricNoCache {
       return std::make_tuple(l, w);
     });
     std::array<double, 2> dat{result.Residue(), result.Weights()};
-    collective::GlobalSum(info, &dat);
+    auto rc = collective::GlobalSum(ctx, info, linalg::MakeVec(dat.data(), dat.size()));
+    collective::SafeColl(rc);
     CHECK_GT(dat[1], 0);
     return dat[0] / dat[1];
   }
@@ -501,5 +505,4 @@ class QuantileError : public MetricNoCache {
 XGBOOST_REGISTER_METRIC(QuantileError, "quantile")
     .describe("Quantile regression error.")
     .set_body([](const char*) { return new QuantileError{}; });
-}  // namespace metric
-}  // namespace xgboost
+}  // namespace xgboost::metric

diff --git a/src/metric/metric_common.h b/src/metric/metric_common.h
index 1b148ab0f..53c38ff2a 100644
--- a/src/metric/metric_common.h
+++ b/src/metric/metric_common.h
@@ -1,6 +1,5 @@
-/*!
- * Copyright 2018-2022 by Contributors
- * \file metric_common.h
+/**
+ * Copyright 2018-2024, Contributors
  */
 #ifndef XGBOOST_METRIC_METRIC_COMMON_H_
 #define XGBOOST_METRIC_METRIC_COMMON_H_
@@ -24,7 +23,7 @@ class MetricNoCache : public Metric {
   double Evaluate(HostDeviceVector<float> const &predts, std::shared_ptr<DMatrix> p_fmat) final {
     double result{0.0};
     auto const &info = p_fmat->Info();
-    collective::ApplyWithLabels(info, &result, sizeof(double),
+    collective::ApplyWithLabels(ctx_, info, &result, sizeof(double),
                                 [&] { result = this->Eval(predts, info); });
     return result;
   }

diff --git a/src/metric/multiclass_metric.cu b/src/metric/multiclass_metric.cu
index 897c91dab..acaef7cf7 100644
--- a/src/metric/multiclass_metric.cu
+++ b/src/metric/multiclass_metric.cu
@@ -1,5 +1,5 @@
 /**
- * Copyright 2015-2023 by XGBoost Contributors
+ * Copyright 2015-2024, XGBoost Contributors
 * \file multiclass_metric.cc
 * \brief evaluation metrics for multiclass classification.
 * \author Kailong Chen, Tianqi Chen
@@ -24,8 +24,7 @@
 #include "../common/device_helpers.cuh"
 #endif  // XGBOOST_USE_CUDA
 
-namespace xgboost {
-namespace metric {
+namespace xgboost::metric {
 // tag this file, used by force static link later.
 DMLC_REGISTRY_FILE_TAG(multiclass_metric);
@@ -40,11 +39,10 @@ class MultiClassMetricsReduction {
  public:
   MultiClassMetricsReduction() = default;
 
-  PackedReduceResult
-  CpuReduceMetrics(const HostDeviceVector<bst_float> &weights,
-                   const HostDeviceVector<bst_float> &labels,
-                   const HostDeviceVector<bst_float> &preds,
-                   const size_t n_class, int32_t n_threads) const {
+  [[nodiscard]] PackedReduceResult CpuReduceMetrics(const HostDeviceVector<bst_float>& weights,
+                                                    const HostDeviceVector<bst_float>& labels,
+                                                    const HostDeviceVector<bst_float>& preds,
+                                                    const size_t n_class, int32_t n_threads) const {
     size_t ndata = labels.Size();
 
     const auto& h_labels = labels.HostVector();
@@ -182,7 +180,8 @@ struct EvalMClassBase : public MetricNoCache {
       dat[0] = result.Residue();
       dat[1] = result.Weights();
     }
-    collective::GlobalSum(info, &dat);
+    auto rc = collective::GlobalSum(ctx_, info, linalg::MakeVec(dat.data(), dat.size()));
+    collective::SafeColl(rc);
     return Derived::GetFinal(dat[0], dat[1]);
   }
   /*!
@@ -245,5 +244,4 @@ XGBOOST_REGISTER_METRIC(MatchError, "merror")
 XGBOOST_REGISTER_METRIC(MultiLogLoss, "mlogloss")
     .describe("Multiclass negative loglikelihood.")
     .set_body([](const char*) { return new EvalMultiLogLoss(); });
-}  // namespace metric
-}  // namespace xgboost
+}  // namespace xgboost::metric

diff --git a/src/metric/rank_metric.cc b/src/metric/rank_metric.cc
index 6762aec32..53841c051 100644
--- a/src/metric/rank_metric.cc
+++ b/src/metric/rank_metric.cc
@@ -101,7 +101,7 @@ struct EvalAMS : public MetricNoCache {
     }
   }
 
-  const char* Name() const override {
+  [[nodiscard]] const char* Name() const override {
     return name_.c_str();
   }
@@ -159,7 +159,7 @@ struct EvalRank : public MetricNoCache, public EvalRankConfig {
       exc.Rethrow();
     }
 
-    return collective::GlobalRatio(info, sum_metric, static_cast<double>(ngroups));
+    return collective::GlobalRatio(ctx_, info, sum_metric, static_cast<double>(ngroups));
   }
 
   [[nodiscard]] const char* Name() const override {
@@ -274,7 +274,7 @@ class EvalRankWithCache : public Metric {
   double Evaluate(HostDeviceVector<float> const& preds, std::shared_ptr<DMatrix> p_fmat) override {
     double result{0.0};
     auto const& info = p_fmat->Info();
-    collective::ApplyWithLabels(info, &result, sizeof(double), [&] {
+    collective::ApplyWithLabels(ctx_, info, &result, sizeof(double), [&] {
       auto p_cache = cache_.CacheItem(p_fmat, ctx_, info, param_);
       if (p_cache->Param() != param_) {
         p_cache = cache_.ResetItem(p_fmat, ctx_, info, param_);
@@ -294,9 +294,10 @@ class EvalRankWithCache : public Metric {
 };
 
 namespace {
-double Finalize(Context const*, MetaInfo const& info, double score, double sw) {
+double Finalize(Context const* ctx, MetaInfo const& info, double score, double sw) {
   std::array<double, 2> dat{score, sw};
-  collective::GlobalSum(info, &dat);
+  auto rc = collective::GlobalSum(ctx, info, linalg::MakeVec(dat.data(), 2));
+  collective::SafeColl(rc);
   std::tie(score, sw) = std::tuple_cat(dat);
   if (sw > 0.0) {
     score = score / sw;

diff --git a/src/metric/rank_metric.cu b/src/metric/rank_metric.cu
index 372eb6805..d43125dcb 100644
--- a/src/metric/rank_metric.cu
+++ b/src/metric/rank_metric.cu
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020-2023 by XGBoost Contributors
+ * Copyright 2020-2024, XGBoost Contributors
  */
 #include
 #include   // for make_counting_iterator

diff --git a/src/metric/survival_metric.cu b/src/metric/survival_metric.cu
index c13702a19..c64fece6c 100644
--- a/src/metric/survival_metric.cu
+++ b/src/metric/survival_metric.cu
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019-2023 by Contributors
+ * Copyright 2019-2024, Contributors
 * \file survival_metric.cu
 * \brief Metrics for survival analysis
 * \author Avinash Barnwal, Hyunsu Cho and Toby Hocking
@@ -30,8 +30,7 @@ using ProbabilityDistributionType = xgboost::common::ProbabilityDistributionType
 template <typename Distribution>
 using AFTLoss = xgboost::common::AFTLoss<Distribution>;
 
-namespace xgboost {
-namespace metric {
+namespace xgboost::metric {
 // tag this file, used by force static link later.
 DMLC_REGISTRY_FILE_TAG(survival_metric);
@@ -43,12 +42,11 @@ class ElementWiseSurvivalMetricsReduction {
     policy_ = policy;
   }
 
-  PackedReduceResult
-  CpuReduceMetrics(const HostDeviceVector<bst_float> &weights,
-                   const HostDeviceVector<bst_float> &labels_lower_bound,
-                   const HostDeviceVector<bst_float> &labels_upper_bound,
-                   const HostDeviceVector<bst_float> &preds,
-                   int32_t n_threads) const {
+  [[nodiscard]] PackedReduceResult CpuReduceMetrics(
+      const HostDeviceVector<bst_float>& weights,
+      const HostDeviceVector<bst_float>& labels_lower_bound,
+      const HostDeviceVector<bst_float>& labels_upper_bound,
+      const HostDeviceVector<bst_float>& preds, int32_t n_threads) const {
     size_t ndata = labels_lower_bound.Size();
     CHECK_EQ(ndata, labels_upper_bound.Size());
@@ -155,7 +153,7 @@ class ElementWiseSurvivalMetricsReduction {
 struct EvalIntervalRegressionAccuracy {
   void Configure(const Args&) {}
 
-  const char* Name() const {
+  [[nodiscard]] const char* Name() const {
     return "interval-regression-accuracy";
   }
@@ -177,7 +175,7 @@ struct EvalAFTNLogLik {
     param_.UpdateAllowUnknown(args);
   }
 
-  const char* Name() const {
+  [[nodiscard]] const char* Name() const {
     return "aft-nloglik";
   }
@@ -213,7 +211,8 @@ struct EvalEWiseSurvivalBase : public MetricNoCache {
                                           info.labels_upper_bound_, preds);
 
     std::array<double, 2> dat{result.Residue(), result.Weights()};
-    collective::GlobalSum(info, &dat);
+    auto rc = collective::GlobalSum(ctx_, info, linalg::MakeVec(dat.data(), dat.size()));
+    collective::SafeColl(rc);
     return Policy::GetFinal(dat[0], dat[1]);
   }
@@ -230,7 +229,7 @@
 // This class exists because we want to perform dispatch according to the distribution type at
 // configuration time, not at prediction time.
 struct AFTNLogLikDispatcher : public MetricNoCache {
-  const char* Name() const override {
+  [[nodiscard]] const char* Name() const override {
     return "aft-nloglik";
   }
@@ -282,5 +281,4 @@ XGBOOST_REGISTER_METRIC(IntervalRegressionAccuracy, "interval-regression-accurac
     return new EvalEWiseSurvivalBase<EvalIntervalRegressionAccuracy>();
   });
-}  // namespace metric
-}  // namespace xgboost
+}  // namespace xgboost::metric

diff --git a/src/objective/adaptive.cc b/src/objective/adaptive.cc
index b195dffd7..e7778c464 100644
--- a/src/objective/adaptive.cc
+++ b/src/objective/adaptive.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2022-2023 by XGBoost Contributors
+ * Copyright 2022-2024, XGBoost Contributors
  */
 #include "adaptive.h"
@@ -85,7 +85,7 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
   size_t n_leaf = nidx.size();
   if (nptr.empty()) {
     std::vector<float> quantiles;
-    UpdateLeafValues(&quantiles, nidx, info, learning_rate, p_tree);
+    UpdateLeafValues(ctx, &quantiles, nidx, info, learning_rate, p_tree);
     return;
   }
@@ -100,7 +100,7 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
                predt.Size() / info.num_row_);
 
   collective::ApplyWithLabels(
-      info, static_cast<void*>(quantiles.data()), quantiles.size() * sizeof(float), [&] {
+      ctx, info, static_cast<void*>(quantiles.data()), quantiles.size() * sizeof(float), [&] {
        // loop over each leaf
        common::ParallelFor(quantiles.size(), ctx->Threads(), [&](size_t k) {
          auto nidx = h_node_idx[k];
@@ -134,7 +134,7 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
     });
   });
 
-  UpdateLeafValues(&quantiles, nidx, info, learning_rate, p_tree);
+  UpdateLeafValues(ctx, &quantiles, nidx, info, learning_rate, p_tree);
 }
 
 #if !defined(XGBOOST_USE_CUDA)

diff --git a/src/objective/adaptive.cu b/src/objective/adaptive.cu
index 07644146b..235e28419 100644
--- a/src/objective/adaptive.cu
+++ b/src/objective/adaptive.cu
@@ -1,5 +1,5 @@
 /**
- * Copyright 2022-2023 by XGBoost Contributors
+ * Copyright 2022-2024, XGBoost Contributors
  */
 #include
@@ -150,7 +150,7 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
 
   if (nptr.Empty()) {
     std::vector<float> quantiles;
-    UpdateLeafValues(&quantiles, nidx.ConstHostVector(), info, learning_rate, p_tree);
+    UpdateLeafValues(ctx, &quantiles, nidx.ConstHostVector(), info, learning_rate, p_tree);
   }
 
   predt.SetDevice(ctx->Device());
@@ -160,7 +160,7 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
   auto t_predt = d_predt.Slice(linalg::All(), group_idx);
 
   HostDeviceVector<float> quantiles;
-  collective::ApplyWithLabels(info, &quantiles, [&] {
+  collective::ApplyWithLabels(ctx, info, &quantiles, [&] {
     auto d_labels = info.labels.View(ctx->Device()).Slice(linalg::All(), IdxY(info, group_idx));
     auto d_row_index = dh::ToSpan(ridx);
     auto seg_beg = nptr.DevicePointer();
@@ -186,6 +186,7 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
                            w_it + d_weights.size(), &quantiles);
     }
   });
-  UpdateLeafValues(&quantiles.HostVector(), nidx.ConstHostVector(), info, learning_rate, p_tree);
+  UpdateLeafValues(ctx, &quantiles.HostVector(), nidx.ConstHostVector(), info, learning_rate,
+                   p_tree);
 }
 }  // namespace xgboost::obj::detail

diff --git a/src/objective/adaptive.h b/src/objective/adaptive.h
index a64f37f63..cbe69e79a 100644
--- a/src/objective/adaptive.h
+++ b/src/objective/adaptive.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2022-2023 by XGBoost Contributors
+ * Copyright 2022-2024, XGBoost Contributors
  */
 #pragma once
@@ -17,8 +17,7 @@
 #include "xgboost/host_device_vector.h"  // HostDeviceVector
 #include "xgboost/tree_model.h"          // RegTree
 
-namespace xgboost {
-namespace obj {
+namespace xgboost::obj {
 namespace detail {
 inline void FillMissingLeaf(std::vector<bst_node_t> const& maybe_missing,
                             std::vector<bst_node_t>* p_nidx, std::vector<size_t>* p_nptr) {
@@ -36,13 +35,14 @@ inline void FillMissingLeaf(std::vector<bst_node_t> const& maybe_missing,
   }
 }
 
-inline void UpdateLeafValues(std::vector<float>* p_quantiles, std::vector<bst_node_t> const& nidx,
-                             MetaInfo const& info, float learning_rate, RegTree* p_tree) {
+inline void UpdateLeafValues(Context const* ctx, std::vector<float>* p_quantiles,
+                             std::vector<bst_node_t> const& nidx, MetaInfo const& info,
+                             float learning_rate, RegTree* p_tree) {
   auto& tree = *p_tree;
   auto& quantiles = *p_quantiles;
   auto const& h_node_idx = nidx;
 
-  size_t n_leaf = collective::GlobalMax(info, h_node_idx.size());
+  size_t n_leaf = collective::GlobalMax(ctx, info, h_node_idx.size());
   CHECK(quantiles.empty() || quantiles.size() == n_leaf);
   if (quantiles.empty()) {
     quantiles.resize(n_leaf, std::numeric_limits<float>::quiet_NaN());
   }
@@ -52,12 +52,16 @@ inline void UpdateLeafValues(Context const* ctx, std::vector<float>* p_quantiles
   std::vector<int32_t> n_valids(quantiles.size());
   std::transform(quantiles.cbegin(), quantiles.cend(), n_valids.begin(),
                 [](float q) { return static_cast<int32_t>(!std::isnan(q)); });
-  collective::GlobalSum(info, &n_valids);
+  auto rc = collective::GlobalSum(ctx, info, linalg::MakeVec(n_valids.data(), n_valids.size()));
+  collective::SafeColl(rc);
+
   // convert to 0 for all reduce
   std::replace_if(
      quantiles.begin(), quantiles.end(), [](float q) { return std::isnan(q); }, 0.f);
   // use the mean value
-  collective::GlobalSum(info, &quantiles);
+  rc = collective::GlobalSum(ctx, info, linalg::MakeVec(quantiles.data(), quantiles.size()));
+  collective::SafeColl(rc);
+
   for (size_t i = 0; i < n_leaf; ++i) {
     if (n_valids[i] > 0) {
       quantiles[i] /= static_cast<float>(n_valids[i]);
@@ -105,5 +109,4 @@ inline void UpdateTreeLeaf(Context const* ctx, HostDeviceVector<float> cons
                            predt, alpha,
                            p_tree);
 }
 }
-}  // namespace obj
-}  // namespace xgboost
+}  // namespace xgboost::obj

diff --git a/src/objective/quantile_obj.cu b/src/objective/quantile_obj.cu
index 15ec72f95..7029a201a 100644
--- a/src/objective/quantile_obj.cu
+++ b/src/objective/quantile_obj.cu
@@ -1,5 +1,5 @@
 /**
- * Copyright 2023 by XGBoost contributors
+ * Copyright 2023-2024, XGBoost contributors
  */
 #include <array>    // std::array
 #include <cstddef>  // std::size_t
@@ -170,7 +170,9 @@ class QuantileRegression : public ObjFunction {
       double meanq = temp(0) * sw;
 
       std::array<double, 2> dat{meanq, sw};
-      collective::GlobalSum(info, &dat);
+      auto rc = collective::GlobalSum(ctx_, info, linalg::MakeVec(dat.data(), dat.size()));
+      collective::SafeColl(rc);
+
       std::tie(meanq, sw) = std::tuple_cat(dat);
       meanq /= (sw + kRtEps);
       base_score->Reshape(1);

diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu
index df30b354b..3b60ff111 100644
--- a/src/objective/regression_obj.cu
+++ b/src/objective/regression_obj.cu
@@ -1,5 +1,5 @@
 /**
- * Copyright 2015-2023 by XGBoost Contributors
+ * Copyright 2015-2024, XGBoost Contributors
 * \file regression_obj.cu
 * \brief Definition of single-value regression and classification objectives.
 * \author Tianqi Chen, Kailong Chen
@@ -672,8 +672,12 @@ class MeanAbsoluteError : public ObjFunction {
     std::transform(linalg::cbegin(out), linalg::cend(out), linalg::begin(out),
                    [w](float v) { return v * w; });
 
-    collective::GlobalSum(info, &out.Values());
-    collective::GlobalSum(info, &w, 1);
+    auto rc = collective::Success() << [&] {
+      return collective::GlobalSum(ctx_, info, out);
+    } << [&] {
+      return collective::GlobalSum(ctx_, info, linalg::MakeVec(&w, 1));
+    };
+    collective::SafeColl(rc);
 
     if (common::CloseTo(w, 0.0)) {
       // Mostly for handling empty dataset test.

diff --git a/src/tree/fit_stump.cc b/src/tree/fit_stump.cc
index 21a050536..5e4d16e4e 100644
--- a/src/tree/fit_stump.cc
+++ b/src/tree/fit_stump.cc
@@ -1,7 +1,7 @@
 /**
- * Copyright 2022 by XGBoost Contributors
+ * Copyright 2022-2024, XGBoost Contributors
  *
- * \brief Utilities for estimating initial score.
+ * @brief Utilities for estimating initial score.
  */
 #include "fit_stump.h"
@@ -44,8 +44,11 @@ void FitStump(Context const* ctx, MetaInfo const& info,
     }
   }
   CHECK(h_sum.CContiguous());
-
-  collective::GlobalSum(info, reinterpret_cast<double *>(h_sum.Values().data()), h_sum.Size() * 2);
+  auto as_double = linalg::MakeTensorView(
+      ctx, common::Span{reinterpret_cast<double *>(h_sum.Values().data()), h_sum.Size() * 2},
+      h_sum.Size() * 2);
+  auto rc = collective::GlobalSum(ctx, info, as_double);
+  collective::SafeColl(rc);
 
   for (std::size_t i = 0; i < h_sum.Size(); ++i) {
     out(i) = static_cast<float>(CalcUnregularizedWeight(h_sum(i).GetGrad(), h_sum(i).GetHess()));

diff --git a/src/tree/fit_stump.cu b/src/tree/fit_stump.cu
index 9fcacd081..832d40754 100644
--- a/src/tree/fit_stump.cu
+++ b/src/tree/fit_stump.cu
@@ -1,19 +1,18 @@
 /**
- * Copyright 2022-2023 by XGBoost Contributors
+ * Copyright 2022-2024, XGBoost Contributors
  *
- * \brief Utilities for estimating initial score.
+ * @brief Utilities for estimating initial score.
  */
 #if !defined(NOMINMAX) && defined(_WIN32)
 #define NOMINMAX
-#endif  // !defined(NOMINMAX)
-#include <thrust/execution_policy.h>             // cuda::par
-#include <thrust/iterator/counting_iterator.h>   // thrust::make_counting_iterator
+#endif  // !defined(NOMINMAX)
+#include <thrust/execution_policy.h>            // cuda::par
+#include <thrust/iterator/counting_iterator.h>  // thrust::make_counting_iterator
 
-#include <cstddef>  // std::size_t
+#include <cstddef>  // std::size_t
 
-#include "../collective/aggregator.cuh"
-#include "../collective/communicator-inl.cuh"
-#include "../common/device_helpers.cuh"  // dh::MakeTransformIterator
+#include "../collective/aggregator.cuh"  // for GlobalSum
+#include "../common/device_helpers.cuh"  // dh::MakeTransformIterator
 #include "fit_stump.h"
 #include "xgboost/base.h"     // GradientPairPrecise, GradientPair, XGBOOST_DEVICE
 #include "xgboost/context.h"  // Context

diff --git a/src/tree/gpu_hist/histogram.cu b/src/tree/gpu_hist/histogram.cu
index c473c9269..90c151556 100644
--- a/src/tree/gpu_hist/histogram.cu
+++ b/src/tree/gpu_hist/histogram.cu
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020-2023 by XGBoost Contributors
+ * Copyright 2020-2024, XGBoost Contributors
  */
 #include
 #include
@@ -52,7 +52,7 @@ struct Clip : public thrust::unary_function<GradientPair, GradientPair> {
  *
 * to avoid outliers, as the full reduction is reproducible on GPU with reduction tree.
 */
-GradientQuantiser::GradientQuantiser(Context const*, common::Span<GradientPair const> gpair,
+GradientQuantiser::GradientQuantiser(Context const* ctx, common::Span<GradientPair const> gpair,
                                      MetaInfo const& info) {
   using GradientSumT = GradientPairPrecise;
   using T = typename GradientSumT::ValueT;
@@ -65,11 +65,14 @@
 
-  collective::GlobalSum(info, reinterpret_cast<float *>(&p), 4);
+  auto rc = collective::GlobalSum(ctx, info, linalg::MakeVec(reinterpret_cast<float *>(&p), 4));
+  collective::SafeColl(rc);
+
   GradientPair positive_sum{p.first}, negative_sum{p.second};
 
   std::size_t total_rows = gpair.size();
-  collective::GlobalSum(info, &total_rows, 1);
+  rc = collective::GlobalSum(ctx, info, linalg::MakeVec(&total_rows, 1));
+  collective::SafeColl(rc);
 
   auto histogram_rounding =
       GradientSumT{common::CreateRoundingFactor<T>(

diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc
index 94e7547ee..68317fc41 100644
--- a/src/tree/updater_approx.cc
+++ b/src/tree/updater_approx.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2021-2023 by XGBoost contributors
+ * Copyright 2021-2024, XGBoost contributors
  *
 * \brief Implementation for the approx tree method.
 */
@@ -107,7 +107,10 @@ class GloablApproxBuilder {
     for (auto const &g : gpair) {
       root_sum.Add(g);
     }
-    collective::GlobalSum(p_fmat->Info(), reinterpret_cast<double *>(&root_sum), 2);
+    auto rc = collective::GlobalSum(ctx_, p_fmat->Info(),
+                                    linalg::MakeVec(reinterpret_cast<double *>(&root_sum), 2));
+    collective::SafeColl(rc);
+
     std::vector<CPUExpandEntry> nodes{best};
     this->histogram_builder_.BuildRootHist(p_fmat, p_tree, partitioner_,
                                            linalg::MakeTensorView(ctx_, gpair, gpair.size(), 1),

diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index 3c9c61f88..4911cec09 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -1,5 +1,5 @@
 /**
- * Copyright 2017-2023 by XGBoost contributors
+ * Copyright 2017-2024, XGBoost contributors
  */
 #include
 #include
@@ -729,7 +729,9 @@ struct GPUHistMakerDevice {
     dh::Reduce(ctx_->CUDACtx()->CTP(), gpair_it, gpair_it + gpair.size(), GradientPairInt64{},
               thrust::plus<GradientPairInt64>{});
     using ReduceT = typename decltype(root_sum_quantised)::ValueT;
-    collective::GlobalSum(info_, reinterpret_cast<ReduceT *>(&root_sum_quantised), 2);
+    auto rc = collective::GlobalSum(
+        ctx_, info_, linalg::MakeVec(reinterpret_cast<ReduceT *>(&root_sum_quantised), 2));
+    collective::SafeColl(rc);
 
     hist.AllocateHistograms({kRootNIdx});
     this->BuildHist(kRootNIdx);

diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc
index cd60e6602..ced277773 100644
--- a/src/tree/updater_quantile_hist.cc
+++ b/src/tree/updater_quantile_hist.cc
@@ -199,8 +199,10 @@ class MultiTargetHistBuilder {
       }
     }
     CHECK(root_sum.CContiguous());
-    collective::GlobalSum(p_fmat->Info(), reinterpret_cast<double *>(root_sum.Values().data()),
-                          root_sum.Size() * 2);
+    auto rc = collective::GlobalSum(
+        ctx_, p_fmat->Info(),
+        linalg::MakeVec(reinterpret_cast<double *>(root_sum.Values().data()), root_sum.Size() * 2));
+    collective::SafeColl(rc);
 
     histogram_builder_->BuildRootHist(p_fmat, p_tree, partitioner_, gpair, best, HistBatch(param_));
@@ -408,7 +410,9 @@ class HistUpdater {
       for (auto const &grad : gpair_h) {
         grad_stat.Add(grad.GetGrad(), grad.GetHess());
       }
-      collective::GlobalSum(p_fmat->Info(), reinterpret_cast<double *>(&grad_stat), 2);
+      auto rc = collective::GlobalSum(ctx_, p_fmat->Info(),
+                                      linalg::MakeVec(reinterpret_cast<double *>(&grad_stat), 2));
+      collective::SafeColl(rc);
     }
 
     auto weight = evaluator_->InitRoot(GradStats{grad_stat});
@@ -471,6 +475,7 @@ class QuantileHistMaker : public TreeUpdater {
   std::unique_ptr<HistUpdater> p_impl_{nullptr};
   std::unique_ptr<MultiTargetHistBuilder> p_mtimpl_{nullptr};
   std::shared_ptr<common::ColumnSampler> column_sampler_;
+  common::Monitor monitor_;
 
   ObjInfo const *task_{nullptr};
   HistMakerTrainParam hist_param_;
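
For reference, below is a minimal sketch of how a call site uses the context-aware aggregator API this patch introduces. It is illustrative only: the function `Example`, the locals `ctx`, `info`, `values`, and `w`, and the include paths are assumptions, not code from the patch; `GlobalSum`, `SafeColl`, `Success`, `linalg::MakeVec`, and the `Result`-chaining `operator<<` are the interfaces added or required by the diff above.

    #include <vector>

    #include "../collective/aggregator.h"   // GlobalSum, GlobalRatio (illustrative path)
    #include "xgboost/collective/result.h"  // Result, SafeColl, Success
    #include "xgboost/context.h"            // Context
    #include "xgboost/data.h"               // MetaInfo
    #include "xgboost/linalg.h"             // MakeVec

    namespace xgboost {
    void Example(Context const* ctx, MetaInfo const& info) {
      std::vector<double> values(4, 1.0);

      // Old style (removed by this patch):
      //   collective::GlobalSum(info, values.data(), values.size());
      // New style: pass the context, wrap the buffer in a tensor view, and
      // check the returned Result explicitly.
      auto rc = collective::GlobalSum(ctx, info,
                                      linalg::MakeVec(values.data(), values.size()));
      collective::SafeColl(rc);  // LOG(FATAL) << rc.Report() on failure

      // Several collective calls can be chained through operator<<; a later
      // lambda runs only when the earlier Result is OK (the same pattern as
      // the regression_obj.cu hunk above).
      double w = 1.0;
      rc = collective::Success() << [&] {
        return collective::GlobalSum(ctx, info,
                                     linalg::MakeVec(values.data(), values.size()));
      } << [&] {
        return collective::GlobalSum(ctx, info, linalg::MakeVec(&w, 1));
      };
      collective::SafeColl(rc);
    }
    }  // namespace xgboost

Returning a `Result` instead of aborting inside the aggregator lets callers either chain further collective calls or escalate at the boundary, which is why each converted call site in this patch ends with `collective::SafeColl(rc);`.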