diff --git a/src/collective/aggregator.h b/src/collective/aggregator.h index 12222cf9d..fe7b65930 100644 --- a/src/collective/aggregator.h +++ b/src/collective/aggregator.h @@ -8,6 +8,7 @@ #pragma once #include +#include #include #include #include @@ -57,5 +58,72 @@ void ApplyWithLabels(MetaInfo const& info, T* buffer, size_t size, Function&& fu std::forward(function)(std::forward(args)...); } } + +/** + * @brief Find the global max of the given value across all workers. + * + * This only applies when the data is split row-wise (horizontally). When data is split + * column-wise (vertically), the local value is returned. + * + * @tparam T The type of the value. + * @param info MetaInfo about the DMatrix. + * @param value The input for finding the global max. + * @return The global max of the input. + */ +template +T GlobalMax(MetaInfo const& info, T value) { + if (info.IsRowSplit()) { + collective::Allreduce(&value, 1); + } + return value; +} + +/** + * @brief Find the global sum of the given values across all workers. + * + * This only applies when the data is split row-wise (horizontally). When data is split + * column-wise (vertically), the original values are returned. + * + * @tparam T The type of the values. + * @param info MetaInfo about the DMatrix. + * @param values Pointer to the inputs to sum. + * @param size Number of values to sum. + */ +template +void GlobalSum(MetaInfo const& info, T* values, size_t size) { + if (info.IsRowSplit()) { + collective::Allreduce(values, size); + } +} + +template +void GlobalSum(MetaInfo const& info, Container* values) { + GlobalSum(info, values->data(), values->size()); +} + +/** + * @brief Find the global ratio of the given two values across all workers. + * + * This only applies when the data is split row-wise (horizontally). When data is split + * column-wise (vertically), the local ratio is returned. + * + * @tparam T The type of the values. + * @param info MetaInfo about the DMatrix. + * @param dividend The dividend of the ratio. + * @param divisor The divisor of the ratio. + * @return The global ratio of the two inputs. + */ +template +T GlobalRatio(MetaInfo const& info, T dividend, T divisor) { + std::array results{dividend, divisor}; + GlobalSum(info, &results); + std::tie(dividend, divisor) = std::tuple_cat(results); + if (divisor <= 0) { + return std::numeric_limits::quiet_NaN(); + } else { + return dividend / divisor; + } +} + } // namespace collective } // namespace xgboost diff --git a/src/metric/auc.cc b/src/metric/auc.cc index bde3127ed..473f5b02c 100644 --- a/src/metric/auc.cc +++ b/src/metric/auc.cc @@ -116,10 +116,7 @@ double MultiClassOVR(Context const *ctx, common::Span predts, MetaI // we have 2 averages going in here, first is among workers, second is among // classes. allreduce sums up fp/tp auc for each class. - if (info.IsRowSplit()) { - collective::Allreduce(results.Values().data(), - results.Values().size()); - } + collective::GlobalSum(info, &results.Values()); double auc_sum{0}; double tp_sum{0}; for (size_t c = 0; c < n_classes; ++c) { @@ -293,17 +290,8 @@ class EvalAUC : public MetricNoCache { InvalidGroupAUC(); } - std::array results{auc, static_cast(valid_groups)}; - if (info.IsRowSplit()) { - collective::Allreduce(results.data(), results.size()); - } - auc = results[0]; - valid_groups = static_cast(results[1]); - - if (valid_groups <= 0) { - auc = std::numeric_limits::quiet_NaN(); - } else { - auc /= valid_groups; + auc = collective::GlobalRatio(info, auc, static_cast(valid_groups)); + if (!std::isnan(auc)) { CHECK_LE(auc, 1) << "Total AUC across groups: " << auc * valid_groups << ", valid groups: " << valid_groups; } @@ -323,19 +311,9 @@ class EvalAUC : public MetricNoCache { std::tie(fp, tp, auc) = static_cast(this)->EvalBinary(preds, info); } - double local_area = fp * tp; - std::array result{auc, local_area}; - if (info.IsRowSplit()) { - collective::Allreduce(result.data(), result.size()); - } - std::tie(auc, local_area) = common::UnpackArr(std::move(result)); - if (local_area <= 0) { - // the dataset across all workers have only positive or negative sample - auc = std::numeric_limits::quiet_NaN(); - } else { - CHECK_LE(auc, local_area); - // normalization - auc = auc / local_area; + auc = collective::GlobalRatio(info, auc, fp * tp); + if (!std::isnan(auc)) { + CHECK_LE(auc, 1.0); } } if (std::isnan(auc)) { diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu index 01aec16e1..bd1b0b2d8 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -8,6 +8,7 @@ */ #include +#include #include #include "../collective/communicator-inl.h" @@ -197,10 +198,8 @@ class PseudoErrorLoss : public MetricNoCache { auto v = common::Sqr(slope) * (std::sqrt((1 + common::Sqr(a / slope))) - 1) * wt; return std::make_tuple(v, wt); }); - double dat[2]{result.Residue(), result.Weights()}; - if (info.IsRowSplit()) { - collective::Allreduce(dat, 2); - } + std::array dat{result.Residue(), result.Weights()}; + collective::GlobalSum(info, &dat); return EvalRowMAPE::GetFinal(dat[0], dat[1]); } }; @@ -366,10 +365,8 @@ struct EvalEWiseBase : public MetricNoCache { return std::make_tuple(residue, wt); }); - double dat[2]{result.Residue(), result.Weights()}; - if (info.IsRowSplit()) { - collective::Allreduce(dat, 2); - } + std::array dat{result.Residue(), result.Weights()}; + collective::GlobalSum(info, &dat); return Policy::GetFinal(dat[0], dat[1]); } @@ -440,10 +437,8 @@ class QuantileError : public MetricNoCache { CHECK(!alpha_.Empty()); if (info.num_row_ == 0) { // empty DMatrix on distributed env - double dat[2]{0.0, 0.0}; - if (info.IsRowSplit()) { - collective::Allreduce(dat, 2); - } + std::array dat{0.0, 0.0}; + collective::GlobalSum(info, &dat); CHECK_GT(dat[1], 0); return dat[0] / dat[1]; } @@ -480,10 +475,8 @@ class QuantileError : public MetricNoCache { loss(y_predt(sample_id, quantile_id, target_id), y_true(sample_id, target_id)) * w; return std::make_tuple(l, w); }); - double dat[2]{result.Residue(), result.Weights()}; - if (info.IsRowSplit()) { - collective::Allreduce(dat, 2); - } + std::array dat{result.Residue(), result.Weights()}; + collective::GlobalSum(info, &dat); CHECK_GT(dat[1], 0); return dat[0] / dat[1]; } diff --git a/src/metric/multiclass_metric.cu b/src/metric/multiclass_metric.cu index a1d19dbc8..f6f3f3d04 100644 --- a/src/metric/multiclass_metric.cu +++ b/src/metric/multiclass_metric.cu @@ -6,6 +6,7 @@ */ #include +#include #include #include @@ -169,7 +170,7 @@ struct EvalMClassBase : public MetricNoCache { } else { CHECK(preds.Size() % info.labels.Size() == 0) << "label and prediction size not match"; } - double dat[2] { 0.0, 0.0 }; + std::array dat{0.0, 0.0}; if (info.labels.Size() != 0) { const size_t nclass = preds.Size() / info.labels.Size(); CHECK_GE(nclass, 1U) @@ -181,9 +182,7 @@ struct EvalMClassBase : public MetricNoCache { dat[0] = result.Residue(); dat[1] = result.Weights(); } - if (info.IsRowSplit()) { - collective::Allreduce(dat, 2); - } + collective::GlobalSum(info, &dat); return Derived::GetFinal(dat[0], dat[1]); } /*! diff --git a/src/metric/rank_metric.cc b/src/metric/rank_metric.cc index 000b88e80..4f272e939 100644 --- a/src/metric/rank_metric.cc +++ b/src/metric/rank_metric.cc @@ -238,14 +238,7 @@ struct EvalRank : public MetricNoCache, public EvalRankConfig { exc.Rethrow(); } - if (collective::IsDistributed() && info.IsRowSplit()) { - double dat[2]{sum_metric, static_cast(ngroups)}; - // approximately estimate the metric using mean - collective::Allreduce(dat, 2); - return dat[0] / dat[1]; - } else { - return sum_metric / ngroups; - } + return collective::GlobalRatio(info, sum_metric, static_cast(ngroups)); } const char* Name() const override { @@ -401,9 +394,8 @@ class EvalRankWithCache : public Metric { namespace { double Finalize(MetaInfo const& info, double score, double sw) { std::array dat{score, sw}; - if (info.IsRowSplit()) { - collective::Allreduce(dat.data(), dat.size()); - } + collective::GlobalSum(info, &dat); + std::tie(score, sw) = std::tuple_cat(dat); if (sw > 0.0) { score = score / sw; } diff --git a/src/metric/survival_metric.cu b/src/metric/survival_metric.cu index 9b1773dc5..5f8c8ee6a 100644 --- a/src/metric/survival_metric.cu +++ b/src/metric/survival_metric.cu @@ -7,6 +7,7 @@ #include +#include #include #include @@ -211,10 +212,8 @@ struct EvalEWiseSurvivalBase : public MetricNoCache { auto result = reducer_.Reduce(*ctx_, info.weights_, info.labels_lower_bound_, info.labels_upper_bound_, preds); - double dat[2]{result.Residue(), result.Weights()}; - if (info.IsRowSplit()) { - collective::Allreduce(dat, 2); - } + std::array dat{result.Residue(), result.Weights()}; + collective::GlobalSum(info, &dat); return Policy::GetFinal(dat[0], dat[1]); } diff --git a/src/objective/adaptive.h b/src/objective/adaptive.h index 7494bceb1..ffd3ddec7 100644 --- a/src/objective/adaptive.h +++ b/src/objective/adaptive.h @@ -6,8 +6,9 @@ #include #include // std::int32_t #include -#include // std::vector +#include // std::vector +#include "../collective/aggregator.h" #include "../collective/communicator-inl.h" #include "../common/common.h" #include "xgboost/base.h" // bst_node_t @@ -41,10 +42,7 @@ inline void UpdateLeafValues(std::vector* p_quantiles, std::vector(&n_leaf, 1); - } + size_t n_leaf = collective::GlobalMax(info, h_node_idx.size()); CHECK(quantiles.empty() || quantiles.size() == n_leaf); if (quantiles.empty()) { quantiles.resize(n_leaf, std::numeric_limits::quiet_NaN()); @@ -54,16 +52,12 @@ inline void UpdateLeafValues(std::vector* p_quantiles, std::vector n_valids(quantiles.size()); std::transform(quantiles.cbegin(), quantiles.cend(), n_valids.begin(), [](float q) { return static_cast(!std::isnan(q)); }); - if (info.IsRowSplit()) { - collective::Allreduce(n_valids.data(), n_valids.size()); - } + collective::GlobalSum(info, &n_valids); // convert to 0 for all reduce std::replace_if( quantiles.begin(), quantiles.end(), [](float q) { return std::isnan(q); }, 0.f); // use the mean value - if (info.IsRowSplit()) { - collective::Allreduce(quantiles.data(), quantiles.size()); - } + collective::GlobalSum(info, &quantiles); for (size_t i = 0; i < n_leaf; ++i) { if (n_valids[i] > 0) { quantiles[i] /= static_cast(n_valids[i]); diff --git a/src/objective/quantile_obj.cu b/src/objective/quantile_obj.cu index b6e540b24..b34f37ff9 100644 --- a/src/objective/quantile_obj.cu +++ b/src/objective/quantile_obj.cu @@ -1,6 +1,7 @@ /** * Copyright 2023 by XGBoost contributors */ +#include // std::array #include // std::size_t #include // std::int32_t #include // std::vector @@ -170,10 +171,9 @@ class QuantileRegression : public ObjFunction { common::Mean(ctx_, *base_score, &temp); double meanq = temp(0) * sw; - if (info.IsRowSplit()) { - collective::Allreduce(&meanq, 1); - collective::Allreduce(&sw, 1); - } + std::array dat{meanq, sw}; + collective::GlobalSum(info, &dat); + std::tie(meanq, sw) = std::tuple_cat(dat); meanq /= (sw + kRtEps); base_score->Reshape(1); base_score->Data()->Fill(meanq); diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index e0dbb2edc..4c5ed9ec8 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -728,10 +728,8 @@ class MeanAbsoluteError : public ObjFunction { std::transform(linalg::cbegin(out), linalg::cend(out), linalg::begin(out), [w](float v) { return v * w; }); - if (info.IsRowSplit()) { - collective::Allreduce(out.Values().data(), out.Values().size()); - collective::Allreduce(&w, 1); - } + collective::GlobalSum(info, &out.Values()); + collective::GlobalSum(info, &w, 1); if (common::CloseTo(w, 0.0)) { // Mostly for handling empty dataset test. diff --git a/src/tree/fit_stump.cc b/src/tree/fit_stump.cc index 55f23b329..3533de772 100644 --- a/src/tree/fit_stump.cc +++ b/src/tree/fit_stump.cc @@ -8,6 +8,7 @@ #include // std::int32_t #include // std::size_t +#include "../collective/aggregator.h" #include "../collective/communicator-inl.h" #include "../common/common.h" // AssertGPUSupport #include "../common/numeric.h" // cpu_impl::Reduce @@ -45,10 +46,7 @@ void FitStump(Context const* ctx, MetaInfo const& info, } CHECK(h_sum.CContiguous()); - if (info.IsRowSplit()) { - collective::Allreduce( - reinterpret_cast(h_sum.Values().data()), h_sum.Size() * 2); - } + collective::GlobalSum(info, reinterpret_cast(h_sum.Values().data()), h_sum.Size() * 2); for (std::size_t i = 0; i < h_sum.Size(); ++i) { out(i) = static_cast(CalcUnregularizedWeight(h_sum(i).GetGrad(), h_sum(i).GetHess())); diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc index d22e8f679..148614a7e 100644 --- a/src/tree/updater_approx.cc +++ b/src/tree/updater_approx.cc @@ -7,6 +7,7 @@ #include #include +#include "../collective/aggregator.h" #include "../common/random.h" #include "../data/gradient_index.h" #include "common_row_partitioner.h" @@ -92,9 +93,7 @@ class GloablApproxBuilder { for (auto const &g : gpair) { root_sum.Add(g); } - if (p_fmat->Info().IsRowSplit()) { - collective::Allreduce(reinterpret_cast(&root_sum), 2); - } + collective::GlobalSum(p_fmat->Info(), reinterpret_cast(&root_sum), 2); std::vector nodes{best}; size_t i = 0; auto space = ConstructHistSpace(partitioner_, nodes);