Require context in aggregators. (#10075)

This commit is contained in:
Jiaming Yuan 2024-02-28 03:12:42 +08:00 committed by GitHub
parent 761845f594
commit 5ac233280e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 190 additions and 144 deletions

View File

@ -17,7 +17,7 @@ AllowShortEnumsOnASingleLine: true
AllowShortBlocksOnASingleLine: Never
AllowShortCaseLabelsOnASingleLine: false
AllowShortFunctionsOnASingleLine: All
AllowShortLambdasOnASingleLine: All
AllowShortLambdasOnASingleLine: Inline
AllowShortIfStatementsOnASingleLine: WithoutElse
AllowShortLoopsOnASingleLine: true
AlwaysBreakAfterDefinitionReturnType: None

View File

@ -1,8 +1,10 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#pragma once
#include <xgboost/logging.h>
#include <memory> // for unique_ptr
#include <sstream> // for stringstream
#include <stack> // for stack
@ -160,10 +162,16 @@ struct Result {
// We don't have monad, a simple helper would do.
template <typename Fn>
Result operator<<(Result&& r, Fn&& fn) {
[[nodiscard]] Result operator<<(Result&& r, Fn&& fn) {
if (!r.OK()) {
return std::forward<Result>(r);
}
return fn();
}
inline void SafeColl(Result const& rc) {
if (!rc.OK()) {
LOG(FATAL) << rc.Report();
}
}
} // namespace xgboost::collective

View File

@ -1,22 +1,21 @@
/**
* Copyright 2023 by XGBoost contributors
* Copyright 2023-2024, XGBoost contributors
*
* Higher level functions built on top the Communicator API, taking care of behavioral differences
* between row-split vs column-split distributed training, and horizontal vs vertical federated
* learning.
*/
#pragma once
#include <xgboost/data.h>
#include <limits>
#include <string>
#include <utility>
#include <vector>
#include "communicator-inl.h"
#include "xgboost/collective/result.h" // for Result
#include "xgboost/data.h" // for MetaINfo
namespace xgboost {
namespace collective {
namespace xgboost::collective {
/**
* @brief Apply the given function where the labels are.
@ -31,15 +30,16 @@ namespace collective {
* @param size The size of the buffer.
* @param function The function used to calculate the results.
*/
template <typename Function>
void ApplyWithLabels(MetaInfo const& info, void* buffer, size_t size, Function&& function) {
template <typename FN>
void ApplyWithLabels(Context const*, MetaInfo const& info, void* buffer, std::size_t size,
FN&& function) {
if (info.IsVerticalFederated()) {
// We assume labels are only available on worker 0, so the calculation is done there and result
// broadcast to other workers.
std::string message;
if (collective::GetRank() == 0) {
try {
std::forward<Function>(function)();
std::forward<FN>(function)();
} catch (dmlc::Error& e) {
message = e.what();
}
@ -52,7 +52,7 @@ void ApplyWithLabels(MetaInfo const& info, void* buffer, size_t size, Function&&
LOG(FATAL) << &message[0];
}
} else {
std::forward<Function>(function)();
std::forward<FN>(function)();
}
}
@ -70,7 +70,8 @@ void ApplyWithLabels(MetaInfo const& info, void* buffer, size_t size, Function&&
* @param function The function used to calculate the results.
*/
template <typename T, typename Function>
void ApplyWithLabels(MetaInfo const& info, HostDeviceVector<T>* result, Function&& function) {
void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector<T>* result,
Function&& function) {
if (info.IsVerticalFederated()) {
// We assume labels are only available on worker 0, so the calculation is done there and result
// broadcast to other workers.
@ -114,7 +115,9 @@ void ApplyWithLabels(MetaInfo const& info, HostDeviceVector<T>* result, Function
* @return The global max of the input.
*/
template <typename T>
T GlobalMax(MetaInfo const& info, T value) {
std::enable_if_t<std::is_trivially_copy_assignable_v<T>, T> GlobalMax(Context const*,
MetaInfo const& info,
T value) {
if (info.IsRowSplit()) {
collective::Allreduce<collective::Operation::kMax>(&value, 1);
}
@ -132,16 +135,18 @@ T GlobalMax(MetaInfo const& info, T value) {
* @param values Pointer to the inputs to sum.
* @param size Number of values to sum.
*/
template <typename T>
void GlobalSum(MetaInfo const& info, T* values, size_t size) {
template <typename T, std::int32_t kDim>
[[nodiscard]] Result GlobalSum(Context const*, MetaInfo const& info,
linalg::TensorView<T, kDim> values) {
if (info.IsRowSplit()) {
collective::Allreduce<collective::Operation::kSum>(values, size);
collective::Allreduce<collective::Operation::kSum>(values.Values().data(), values.Size());
}
return Success();
}
template <typename Container>
void GlobalSum(MetaInfo const& info, Container* values) {
GlobalSum(info, values->data(), values->size());
[[nodiscard]] Result GlobalSum(Context const* ctx, MetaInfo const& info, Container* values) {
return GlobalSum(ctx, info, values->data(), values->size());
}
/**
@ -157,9 +162,10 @@ void GlobalSum(MetaInfo const& info, Container* values) {
* @return The global ratio of the two inputs.
*/
template <typename T>
T GlobalRatio(MetaInfo const& info, T dividend, T divisor) {
T GlobalRatio(Context const* ctx, MetaInfo const& info, T dividend, T divisor) {
std::array<T, 2> results{dividend, divisor};
GlobalSum(info, &results);
auto rc = GlobalSum(ctx, info, linalg::MakeVec(results.data(), results.size()));
collective::SafeColl(rc);
std::tie(dividend, divisor) = std::tuple_cat(results);
if (divisor <= 0) {
return std::numeric_limits<T>::quiet_NaN();
@ -167,6 +173,4 @@ T GlobalRatio(MetaInfo const& info, T dividend, T divisor) {
return dividend / divisor;
}
}
} // namespace collective
} // namespace xgboost
} // namespace xgboost::collective

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2020-2022 by XGBoost Contributors
/**
* Copyright 2020-2024, XGBoost Contributors
*/
#include "quantile.h"
@ -145,7 +145,7 @@ struct QuantileAllreduce {
template <typename WQSketch>
void SketchContainerImpl<WQSketch>::GatherSketchInfo(
Context const *, MetaInfo const &info,
Context const *ctx, MetaInfo const &info,
std::vector<typename WQSketch::SummaryContainer> const &reduced,
std::vector<size_t> *p_worker_segments, std::vector<bst_row_t> *p_sketches_scan,
std::vector<typename WQSketch::Entry> *p_global_sketches) {
@ -171,7 +171,9 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
std::partial_sum(sketch_size.cbegin(), sketch_size.cend(), sketches_scan.begin() + beg_scan + 1);
// Gather all column pointers
collective::GlobalSum(info, sketches_scan.data(), sketches_scan.size());
auto rc =
collective::GlobalSum(ctx, info, linalg::MakeVec(sketches_scan.data(), sketches_scan.size()));
collective::SafeColl(rc);
for (int32_t i = 0; i < world; ++i) {
size_t back = (i + 1) * (n_columns + 1) - 1;
auto n_entries = sketches_scan.at(back);
@ -199,14 +201,15 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
static_assert(sizeof(typename WQSketch::Entry) / 4 == sizeof(float),
"Unexpected size of sketch entry.");
collective::GlobalSum(
info,
reinterpret_cast<float *>(global_sketches.data()),
global_sketches.size() * sizeof(typename WQSketch::Entry) / sizeof(float));
rc = collective::GlobalSum(
ctx, info,
linalg::MakeVec(reinterpret_cast<float *>(global_sketches.data()),
global_sketches.size() * sizeof(typename WQSketch::Entry) / sizeof(float)));
collective::SafeColl(rc);
}
template <typename WQSketch>
void SketchContainerImpl<WQSketch>::AllreduceCategories(Context const*, MetaInfo const& info) {
void SketchContainerImpl<WQSketch>::AllreduceCategories(Context const* ctx, MetaInfo const& info) {
auto world_size = collective::GetWorldSize();
auto rank = collective::GetRank();
if (world_size == 1 || info.IsColumnSplit()) {
@ -226,7 +229,8 @@ void SketchContainerImpl<WQSketch>::AllreduceCategories(Context const*, MetaInfo
std::vector<size_t> global_feat_ptrs(feature_ptr.size() * world_size, 0);
size_t feat_begin = rank * feature_ptr.size(); // pointer to current worker
std::copy(feature_ptr.begin(), feature_ptr.end(), global_feat_ptrs.begin() + feat_begin);
collective::GlobalSum(info, global_feat_ptrs.data(), global_feat_ptrs.size());
auto rc = collective::GlobalSum(
ctx, info, linalg::MakeVec(global_feat_ptrs.data(), global_feat_ptrs.size()));
// move all categories into a flatten vector to prepare for allreduce
size_t total = feature_ptr.back();
@ -239,7 +243,8 @@ void SketchContainerImpl<WQSketch>::AllreduceCategories(Context const*, MetaInfo
// indptr for indexing workers
std::vector<size_t> global_worker_ptr(world_size + 1, 0);
global_worker_ptr[rank + 1] = total; // shift 1 to right for constructing the indptr
collective::GlobalSum(info, global_worker_ptr.data(), global_worker_ptr.size());
rc = collective::GlobalSum(ctx, info,
linalg::MakeVec(global_worker_ptr.data(), global_worker_ptr.size()));
std::partial_sum(global_worker_ptr.cbegin(), global_worker_ptr.cend(), global_worker_ptr.begin());
// total number of categories in all workers with all features
auto gtotal = global_worker_ptr.back();
@ -251,7 +256,8 @@ void SketchContainerImpl<WQSketch>::AllreduceCategories(Context const*, MetaInfo
CHECK_EQ(rank_size, total);
std::copy(flatten.cbegin(), flatten.cend(), global_categories.begin() + rank_begin);
// gather values from all workers.
collective::GlobalSum(info, global_categories.data(), global_categories.size());
rc = collective::GlobalSum(ctx, info,
linalg::MakeVec(global_categories.data(), global_categories.size()));
QuantileAllreduce<float> allreduce_result{global_categories, global_worker_ptr, global_feat_ptrs,
categories_.size()};
ParallelFor(categories_.size(), n_threads_, [&](auto fidx) {
@ -293,7 +299,9 @@ void SketchContainerImpl<WQSketch>::AllReduce(
// Prune the intermediate num cuts for synchronization.
std::vector<bst_row_t> global_column_size(columns_size_);
collective::GlobalSum(info, &global_column_size);
auto rc = collective::GlobalSum(
ctx, info, linalg::MakeVec(global_column_size.data(), global_column_size.size()));
collective::SafeColl(rc);
ParallelFor(sketches_.size(), n_threads_, [&](size_t i) {
int32_t intermediate_num_cuts = static_cast<int32_t>(

View File

@ -1,5 +1,5 @@
/**
* Copyright 2014-2023 by XGBoost Contributors
* Copyright 2014-2024, XGBoost Contributors
* \file learner.cc
* \brief Implementation of learning algorithm.
* \author Tianqi Chen
@ -846,7 +846,7 @@ class LearnerConfiguration : public Learner {
void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
base_score->Reshape(1);
collective::ApplyWithLabels(info, base_score->Data(),
collective::ApplyWithLabels(this->Ctx(), info, base_score->Data(),
[&] { UsePtr(obj_)->InitEstimation(info, base_score); });
}
};
@ -1472,7 +1472,7 @@ class LearnerImpl : public LearnerIO {
void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info,
std::int32_t iter, linalg::Matrix<GradientPair>* out_gpair) {
out_gpair->Reshape(info.num_row_, this->learner_model_param_.OutputLength());
collective::ApplyWithLabels(info, out_gpair->Data(),
collective::ApplyWithLabels(&ctx_, info, out_gpair->Data(),
[&] { obj_->GetGradient(preds, info, iter, out_gpair); });
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021-2023 by XGBoost Contributors
* Copyright 2021-2024, XGBoost Contributors
*/
#include "auc.h"
@ -112,7 +112,9 @@ double MultiClassOVR(Context const *ctx, common::Span<float const> predts, MetaI
// we have 2 averages going in here, first is among workers, second is among
// classes. allreduce sums up fp/tp auc for each class.
collective::GlobalSum(info, &results.Values());
auto rc = collective::GlobalSum(ctx, info, results);
collective::SafeColl(rc);
double auc_sum{0};
double tp_sum{0};
for (size_t c = 0; c < n_classes; ++c) {
@ -286,7 +288,7 @@ class EvalAUC : public MetricNoCache {
InvalidGroupAUC();
}
auc = collective::GlobalRatio(info, auc, static_cast<double>(valid_groups));
auc = collective::GlobalRatio(ctx_, info, auc, static_cast<double>(valid_groups));
if (!std::isnan(auc)) {
CHECK_LE(auc, 1) << "Total AUC across groups: " << auc * valid_groups
<< ", valid groups: " << valid_groups;
@ -307,7 +309,7 @@ class EvalAUC : public MetricNoCache {
std::tie(fp, tp, auc) =
static_cast<Curve *>(this)->EvalBinary(preds, info);
}
auc = collective::GlobalRatio(info, auc, fp * tp);
auc = collective::GlobalRatio(ctx_, info, auc, fp * tp);
if (!std::isnan(auc)) {
CHECK_LE(auc, 1.0);
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2015-2023 by XGBoost Contributors
* Copyright 2015-2024, XGBoost Contributors
* \file elementwise_metric.cu
* \brief evaluation metrics for elementwise binary or regression.
* \author Kailong Chen, Tianqi Chen
@ -19,6 +19,7 @@
#include "../common/quantile_loss_utils.h" // QuantileLossParam
#include "../common/threading_utils.h"
#include "metric_common.h"
#include "xgboost/collective/result.h" // for SafeColl
#include "xgboost/metric.h"
#if defined(XGBOOST_USE_CUDA)
@ -30,8 +31,7 @@
#include "../common/device_helpers.cuh"
#endif // XGBOOST_USE_CUDA
namespace xgboost {
namespace metric {
namespace xgboost::metric {
// tag the this file, used by force static link later.
DMLC_REGISTRY_FILE_TAG(elementwise_metric);
@ -199,7 +199,8 @@ class PseudoErrorLoss : public MetricNoCache {
return std::make_tuple(v, wt);
});
std::array<double, 2> dat{result.Residue(), result.Weights()};
collective::GlobalSum(info, &dat);
auto rc = collective::GlobalSum(ctx_, info, linalg::MakeVec(dat.data(), dat.size()));
collective::SafeColl(rc);
return EvalRowMAPE::GetFinal(dat[0], dat[1]);
}
};
@ -243,11 +244,11 @@ struct EvalError {
};
struct EvalPoissonNegLogLik {
const char *Name() const {
[[nodiscard]] const char *Name() const {
return "poisson-nloglik";
}
XGBOOST_DEVICE bst_float EvalRow(bst_float y, bst_float py) const {
[[nodiscard]] XGBOOST_DEVICE bst_float EvalRow(bst_float y, bst_float py) const {
const bst_float eps = 1e-16f;
if (py < eps) py = eps;
return common::LogGamma(y + 1.0f) + py - std::log(py) * y;
@ -266,9 +267,9 @@ struct EvalPoissonNegLogLik {
* predt >= 0
*/
struct EvalGammaDeviance {
const char *Name() const { return "gamma-deviance"; }
[[nodiscard]] const char *Name() const { return "gamma-deviance"; }
XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float predt) const {
[[nodiscard]] XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float predt) const {
predt += kRtEps;
label += kRtEps;
return std::log(predt / label) + label / predt - 1;
@ -287,7 +288,7 @@ struct EvalGammaNLogLik {
return "gamma-nloglik";
}
XGBOOST_DEVICE bst_float EvalRow(bst_float y, bst_float py) const {
[[nodiscard]] XGBOOST_DEVICE bst_float EvalRow(bst_float y, bst_float py) const {
py = std::max(py, 1e-6f);
// hardcoded dispersion.
float constexpr kPsi = 1.0;
@ -313,7 +314,7 @@ struct EvalTweedieNLogLik {
CHECK(rho_ < 2 && rho_ >= 1)
<< "tweedie variance power must be in interval [1, 2)";
}
const char *Name() const {
[[nodiscard]] const char *Name() const {
static thread_local std::string name;
std::ostringstream os;
os << "tweedie-nloglik@" << rho_;
@ -321,7 +322,7 @@ struct EvalTweedieNLogLik {
return name.c_str();
}
XGBOOST_DEVICE bst_float EvalRow(bst_float y, bst_float p) const {
[[nodiscard]] XGBOOST_DEVICE bst_float EvalRow(bst_float y, bst_float p) const {
bst_float a = y * std::exp((1 - rho_) * std::log(p)) / (1 - rho_);
bst_float b = std::exp((2 - rho_) * std::log(p)) / (2 - rho_);
return -a + b;
@ -366,7 +367,8 @@ struct EvalEWiseBase : public MetricNoCache {
});
std::array<double, 2> dat{result.Residue(), result.Weights()};
collective::GlobalSum(info, &dat);
auto rc = collective::GlobalSum(ctx_, info, linalg::MakeVec(dat.data(), dat.size()));
collective::SafeColl(rc);
return Policy::GetFinal(dat[0], dat[1]);
}
@ -438,7 +440,8 @@ class QuantileError : public MetricNoCache {
if (info.num_row_ == 0) {
// empty DMatrix on distributed env
std::array<double, 2> dat{0.0, 0.0};
collective::GlobalSum(info, &dat);
auto rc = collective::GlobalSum(ctx_, info, linalg::MakeVec(dat.data(), dat.size()));
collective::SafeColl(rc);
CHECK_GT(dat[1], 0);
return dat[0] / dat[1];
}
@ -476,7 +479,8 @@ class QuantileError : public MetricNoCache {
return std::make_tuple(l, w);
});
std::array<double, 2> dat{result.Residue(), result.Weights()};
collective::GlobalSum(info, &dat);
auto rc = collective::GlobalSum(ctx, info, linalg::MakeVec(dat.data(), dat.size()));
collective::SafeColl(rc);
CHECK_GT(dat[1], 0);
return dat[0] / dat[1];
}
@ -501,5 +505,4 @@ class QuantileError : public MetricNoCache {
XGBOOST_REGISTER_METRIC(QuantileError, "quantile")
.describe("Quantile regression error.")
.set_body([](const char*) { return new QuantileError{}; });
} // namespace metric
} // namespace xgboost
} // namespace xgboost::metric

View File

@ -1,6 +1,5 @@
/*!
* Copyright 2018-2022 by Contributors
* \file metric_common.h
/**
* Copyright 2018-2024, Contributors
*/
#ifndef XGBOOST_METRIC_METRIC_COMMON_H_
#define XGBOOST_METRIC_METRIC_COMMON_H_
@ -24,7 +23,7 @@ class MetricNoCache : public Metric {
double Evaluate(HostDeviceVector<float> const &predts, std::shared_ptr<DMatrix> p_fmat) final {
double result{0.0};
auto const &info = p_fmat->Info();
collective::ApplyWithLabels(info, &result, sizeof(double),
collective::ApplyWithLabels(ctx_, info, &result, sizeof(double),
[&] { result = this->Eval(predts, info); });
return result;
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2015-2023 by XGBoost Contributors
* Copyright 2015-2024, XGBoost Contributors
* \file multiclass_metric.cc
* \brief evaluation metrics for multiclass classification.
* \author Kailong Chen, Tianqi Chen
@ -24,8 +24,7 @@
#include "../common/device_helpers.cuh"
#endif // XGBOOST_USE_CUDA
namespace xgboost {
namespace metric {
namespace xgboost::metric {
// tag the this file, used by force static link later.
DMLC_REGISTRY_FILE_TAG(multiclass_metric);
@ -40,10 +39,9 @@ class MultiClassMetricsReduction {
public:
MultiClassMetricsReduction() = default;
PackedReduceResult
CpuReduceMetrics(const HostDeviceVector<bst_float> &weights,
const HostDeviceVector<bst_float> &labels,
const HostDeviceVector<bst_float> &preds,
[[nodiscard]] PackedReduceResult CpuReduceMetrics(const HostDeviceVector<bst_float>& weights,
const HostDeviceVector<bst_float>& labels,
const HostDeviceVector<bst_float>& preds,
const size_t n_class, int32_t n_threads) const {
size_t ndata = labels.Size();
@ -182,7 +180,8 @@ struct EvalMClassBase : public MetricNoCache {
dat[0] = result.Residue();
dat[1] = result.Weights();
}
collective::GlobalSum(info, &dat);
auto rc = collective::GlobalSum(ctx_, info, linalg::MakeVec(dat.data(), dat.size()));
collective::SafeColl(rc);
return Derived::GetFinal(dat[0], dat[1]);
}
/*!
@ -245,5 +244,4 @@ XGBOOST_REGISTER_METRIC(MatchError, "merror")
XGBOOST_REGISTER_METRIC(MultiLogLoss, "mlogloss")
.describe("Multiclass negative loglikelihood.")
.set_body([](const char*) { return new EvalMultiLogLoss(); });
} // namespace metric
} // namespace xgboost
} // namespace xgboost::metric

View File

@ -101,7 +101,7 @@ struct EvalAMS : public MetricNoCache {
}
}
const char* Name() const override {
[[nodiscard]] const char* Name() const override {
return name_.c_str();
}
@ -159,7 +159,7 @@ struct EvalRank : public MetricNoCache, public EvalRankConfig {
exc.Rethrow();
}
return collective::GlobalRatio(info, sum_metric, static_cast<double>(ngroups));
return collective::GlobalRatio(ctx_, info, sum_metric, static_cast<double>(ngroups));
}
[[nodiscard]] const char* Name() const override {
@ -274,7 +274,7 @@ class EvalRankWithCache : public Metric {
double Evaluate(HostDeviceVector<float> const& preds, std::shared_ptr<DMatrix> p_fmat) override {
double result{0.0};
auto const& info = p_fmat->Info();
collective::ApplyWithLabels(info, &result, sizeof(double), [&] {
collective::ApplyWithLabels(ctx_, info, &result, sizeof(double), [&] {
auto p_cache = cache_.CacheItem(p_fmat, ctx_, info, param_);
if (p_cache->Param() != param_) {
p_cache = cache_.ResetItem(p_fmat, ctx_, info, param_);
@ -294,9 +294,10 @@ class EvalRankWithCache : public Metric {
};
namespace {
double Finalize(Context const*, MetaInfo const& info, double score, double sw) {
double Finalize(Context const* ctx, MetaInfo const& info, double score, double sw) {
std::array<double, 2> dat{score, sw};
collective::GlobalSum(info, &dat);
auto rc = collective::GlobalSum(ctx, info, linalg::MakeVec(dat.data(), 2));
collective::SafeColl(rc);
std::tie(score, sw) = std::tuple_cat(dat);
if (sw > 0.0) {
score = score / sw;

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2023 by XGBoost Contributors
* Copyright 2020-2024, XGBoost Contributors
*/
#include <dmlc/registry.h>
#include <thrust/iterator/counting_iterator.h> // for make_counting_iterator

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2023 by Contributors
* Copyright 2019-2024, Contributors
* \file survival_metric.cu
* \brief Metrics for survival analysis
* \author Avinash Barnwal, Hyunsu Cho and Toby Hocking
@ -30,8 +30,7 @@ using ProbabilityDistributionType = xgboost::common::ProbabilityDistributionType
template <typename Distribution>
using AFTLoss = xgboost::common::AFTLoss<Distribution>;
namespace xgboost {
namespace metric {
namespace xgboost::metric {
// tag the this file, used by force static link later.
DMLC_REGISTRY_FILE_TAG(survival_metric);
@ -43,12 +42,11 @@ class ElementWiseSurvivalMetricsReduction {
policy_ = policy;
}
PackedReduceResult
CpuReduceMetrics(const HostDeviceVector<bst_float> &weights,
const HostDeviceVector<bst_float> &labels_lower_bound,
const HostDeviceVector<bst_float> &labels_upper_bound,
const HostDeviceVector<bst_float> &preds,
int32_t n_threads) const {
[[nodiscard]] PackedReduceResult CpuReduceMetrics(
const HostDeviceVector<bst_float>& weights,
const HostDeviceVector<bst_float>& labels_lower_bound,
const HostDeviceVector<bst_float>& labels_upper_bound,
const HostDeviceVector<bst_float>& preds, int32_t n_threads) const {
size_t ndata = labels_lower_bound.Size();
CHECK_EQ(ndata, labels_upper_bound.Size());
@ -155,7 +153,7 @@ class ElementWiseSurvivalMetricsReduction {
struct EvalIntervalRegressionAccuracy {
void Configure(const Args&) {}
const char* Name() const {
[[nodiscard]] const char* Name() const {
return "interval-regression-accuracy";
}
@ -177,7 +175,7 @@ struct EvalAFTNLogLik {
param_.UpdateAllowUnknown(args);
}
const char* Name() const {
[[nodiscard]] const char* Name() const {
return "aft-nloglik";
}
@ -213,7 +211,8 @@ struct EvalEWiseSurvivalBase : public MetricNoCache {
info.labels_upper_bound_, preds);
std::array<double, 2> dat{result.Residue(), result.Weights()};
collective::GlobalSum(info, &dat);
auto rc = collective::GlobalSum(ctx_, info, linalg::MakeVec(dat.data(), dat.size()));
collective::SafeColl(rc);
return Policy::GetFinal(dat[0], dat[1]);
}
@ -230,7 +229,7 @@ struct EvalEWiseSurvivalBase : public MetricNoCache {
// This class exists because we want to perform dispatch according to the distribution type at
// configuration time, not at prediction time.
struct AFTNLogLikDispatcher : public MetricNoCache {
const char* Name() const override {
[[nodiscard]] const char* Name() const override {
return "aft-nloglik";
}
@ -282,5 +281,4 @@ XGBOOST_REGISTER_METRIC(IntervalRegressionAccuracy, "interval-regression-accurac
return new EvalEWiseSurvivalBase<EvalIntervalRegressionAccuracy>();
});
} // namespace metric
} // namespace xgboost
} // namespace xgboost::metric

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022-2023 by XGBoost Contributors
* Copyright 2022-2024, XGBoost Contributors
*/
#include "adaptive.h"
@ -85,7 +85,7 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
size_t n_leaf = nidx.size();
if (nptr.empty()) {
std::vector<float> quantiles;
UpdateLeafValues(&quantiles, nidx, info, learning_rate, p_tree);
UpdateLeafValues(ctx, &quantiles, nidx, info, learning_rate, p_tree);
return;
}
@ -100,7 +100,7 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
predt.Size() / info.num_row_);
collective::ApplyWithLabels(
info, static_cast<void*>(quantiles.data()), quantiles.size() * sizeof(float), [&] {
ctx, info, static_cast<void*>(quantiles.data()), quantiles.size() * sizeof(float), [&] {
// loop over each leaf
common::ParallelFor(quantiles.size(), ctx->Threads(), [&](size_t k) {
auto nidx = h_node_idx[k];
@ -134,7 +134,7 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
});
});
UpdateLeafValues(&quantiles, nidx, info, learning_rate, p_tree);
UpdateLeafValues(ctx, &quantiles, nidx, info, learning_rate, p_tree);
}
#if !defined(XGBOOST_USE_CUDA)

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022-2023 by XGBoost Contributors
* Copyright 2022-2024, XGBoost Contributors
*/
#include <thrust/sort.h>
@ -150,7 +150,7 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
if (nptr.Empty()) {
std::vector<float> quantiles;
UpdateLeafValues(&quantiles, nidx.ConstHostVector(), info, learning_rate, p_tree);
UpdateLeafValues(ctx, &quantiles, nidx.ConstHostVector(), info, learning_rate, p_tree);
}
predt.SetDevice(ctx->Device());
@ -160,7 +160,7 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
auto t_predt = d_predt.Slice(linalg::All(), group_idx);
HostDeviceVector<float> quantiles;
collective::ApplyWithLabels(info, &quantiles, [&] {
collective::ApplyWithLabels(ctx, info, &quantiles, [&] {
auto d_labels = info.labels.View(ctx->Device()).Slice(linalg::All(), IdxY(info, group_idx));
auto d_row_index = dh::ToSpan(ridx);
auto seg_beg = nptr.DevicePointer();
@ -186,6 +186,7 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
w_it + d_weights.size(), &quantiles);
}
});
UpdateLeafValues(&quantiles.HostVector(), nidx.ConstHostVector(), info, learning_rate, p_tree);
UpdateLeafValues(ctx, &quantiles.HostVector(), nidx.ConstHostVector(), info, learning_rate,
p_tree);
}
} // namespace xgboost::obj::detail

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022-2023 by XGBoost Contributors
* Copyright 2022-2024, XGBoost Contributors
*/
#pragma once
@ -17,8 +17,7 @@
#include "xgboost/host_device_vector.h" // HostDeviceVector
#include "xgboost/tree_model.h" // RegTree
namespace xgboost {
namespace obj {
namespace xgboost::obj {
namespace detail {
inline void FillMissingLeaf(std::vector<bst_node_t> const& maybe_missing,
std::vector<bst_node_t>* p_nidx, std::vector<size_t>* p_nptr) {
@ -36,13 +35,14 @@ inline void FillMissingLeaf(std::vector<bst_node_t> const& maybe_missing,
}
}
inline void UpdateLeafValues(std::vector<float>* p_quantiles, std::vector<bst_node_t> const& nidx,
MetaInfo const& info, float learning_rate, RegTree* p_tree) {
inline void UpdateLeafValues(Context const* ctx, std::vector<float>* p_quantiles,
std::vector<bst_node_t> const& nidx, MetaInfo const& info,
float learning_rate, RegTree* p_tree) {
auto& tree = *p_tree;
auto& quantiles = *p_quantiles;
auto const& h_node_idx = nidx;
size_t n_leaf = collective::GlobalMax(info, h_node_idx.size());
size_t n_leaf = collective::GlobalMax(ctx, info, h_node_idx.size());
CHECK(quantiles.empty() || quantiles.size() == n_leaf);
if (quantiles.empty()) {
quantiles.resize(n_leaf, std::numeric_limits<float>::quiet_NaN());
@ -52,12 +52,16 @@ inline void UpdateLeafValues(std::vector<float>* p_quantiles, std::vector<bst_no
std::vector<int32_t> n_valids(quantiles.size());
std::transform(quantiles.cbegin(), quantiles.cend(), n_valids.begin(),
[](float q) { return static_cast<int32_t>(!std::isnan(q)); });
collective::GlobalSum(info, &n_valids);
auto rc = collective::GlobalSum(ctx, info, linalg::MakeVec(n_valids.data(), n_valids.size()));
collective::SafeColl(rc);
// convert to 0 for all reduce
std::replace_if(
quantiles.begin(), quantiles.end(), [](float q) { return std::isnan(q); }, 0.f);
// use the mean value
collective::GlobalSum(info, &quantiles);
rc = collective::GlobalSum(ctx, info, linalg::MakeVec(quantiles.data(), quantiles.size()));
collective::SafeColl(rc);
for (size_t i = 0; i < n_leaf; ++i) {
if (n_valids[i] > 0) {
quantiles[i] /= static_cast<float>(n_valids[i]);
@ -105,5 +109,4 @@ inline void UpdateTreeLeaf(Context const* ctx, HostDeviceVector<bst_node_t> cons
predt, alpha, p_tree);
}
}
} // namespace obj
} // namespace xgboost
} // namespace xgboost::obj

View File

@ -1,5 +1,5 @@
/**
* Copyright 2023 by XGBoost contributors
* Copyright 2023-2024, XGBoost contributors
*/
#include <array> // std::array
#include <cstddef> // std::size_t
@ -170,7 +170,9 @@ class QuantileRegression : public ObjFunction {
double meanq = temp(0) * sw;
std::array<double, 2> dat{meanq, sw};
collective::GlobalSum(info, &dat);
auto rc = collective::GlobalSum(ctx_, info, linalg::MakeVec(dat.data(), dat.size()));
collective::SafeColl(rc);
std::tie(meanq, sw) = std::tuple_cat(dat);
meanq /= (sw + kRtEps);
base_score->Reshape(1);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2015-2023 by XGBoost Contributors
* Copyright 2015-2024, XGBoost Contributors
* \file regression_obj.cu
* \brief Definition of single-value regression and classification objectives.
* \author Tianqi Chen, Kailong Chen
@ -672,8 +672,12 @@ class MeanAbsoluteError : public ObjFunction {
std::transform(linalg::cbegin(out), linalg::cend(out), linalg::begin(out),
[w](float v) { return v * w; });
collective::GlobalSum(info, &out.Values());
collective::GlobalSum(info, &w, 1);
auto rc = collective::Success() << [&] {
return collective::GlobalSum(ctx_, info, out);
} << [&] {
return collective::GlobalSum(ctx_, info, linalg::MakeVec(&w, 1));
};
collective::SafeColl(rc);
if (common::CloseTo(w, 0.0)) {
// Mostly for handling empty dataset test.

View File

@ -1,7 +1,7 @@
/**
* Copyright 2022 by XGBoost Contributors
* Copyright 2022-2024, XGBoost Contributors
*
* \brief Utilities for estimating initial score.
* @brief Utilities for estimating initial score.
*/
#include "fit_stump.h"
@ -44,8 +44,11 @@ void FitStump(Context const* ctx, MetaInfo const& info,
}
}
CHECK(h_sum.CContiguous());
collective::GlobalSum(info, reinterpret_cast<double*>(h_sum.Values().data()), h_sum.Size() * 2);
auto as_double = linalg::MakeTensorView(
ctx, common::Span{reinterpret_cast<double*>(h_sum.Values().data()), h_sum.Size() * 2},
h_sum.Size() * 2);
auto rc = collective::GlobalSum(ctx, info, as_double);
collective::SafeColl(rc);
for (std::size_t i = 0; i < h_sum.Size(); ++i) {
out(i) = static_cast<float>(CalcUnregularizedWeight(h_sum(i).GetGrad(), h_sum(i).GetHess()));

View File

@ -1,7 +1,7 @@
/**
* Copyright 2022-2023 by XGBoost Contributors
* Copyright 2022-2024, XGBoost Contributors
*
* \brief Utilities for estimating initial score.
* @brief Utilities for estimating initial score.
*/
#if !defined(NOMINMAX) && defined(_WIN32)
#define NOMINMAX
@ -11,8 +11,7 @@
#include <cstddef> // std::size_t
#include "../collective/aggregator.cuh"
#include "../collective/communicator-inl.cuh"
#include "../collective/aggregator.cuh" // for GlobalSum
#include "../common/device_helpers.cuh" // dh::MakeTransformIterator
#include "fit_stump.h"
#include "xgboost/base.h" // GradientPairPrecise, GradientPair, XGBOOST_DEVICE

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2023 by XGBoost Contributors
* Copyright 2020-2024, XGBoost Contributors
*/
#include <thrust/iterator/transform_iterator.h>
#include <thrust/reduce.h>
@ -52,7 +52,7 @@ struct Clip : public thrust::unary_function<GradientPair, Pair> {
*
* to avoid outliers, as the full reduction is reproducible on GPU with reduction tree.
*/
GradientQuantiser::GradientQuantiser(Context const*, common::Span<GradientPair const> gpair,
GradientQuantiser::GradientQuantiser(Context const* ctx, common::Span<GradientPair const> gpair,
MetaInfo const& info) {
using GradientSumT = GradientPairPrecise;
using T = typename GradientSumT::ValueT;
@ -65,11 +65,14 @@ GradientQuantiser::GradientQuantiser(Context const*, common::Span<GradientPair c
// Treat pair as array of 4 primitive types to allreduce
using ReduceT = typename decltype(p.first)::ValueT;
static_assert(sizeof(Pair) == sizeof(ReduceT) * 4, "Expected to reduce four elements.");
collective::GlobalSum(info, reinterpret_cast<ReduceT*>(&p), 4);
auto rc = collective::GlobalSum(ctx, info, linalg::MakeVec(reinterpret_cast<ReduceT*>(&p), 4));
collective::SafeColl(rc);
GradientPair positive_sum{p.first}, negative_sum{p.second};
std::size_t total_rows = gpair.size();
collective::GlobalSum(info, &total_rows, 1);
rc = collective::GlobalSum(ctx, info, linalg::MakeVec(&total_rows, 1));
collective::SafeColl(rc);
auto histogram_rounding =
GradientSumT{common::CreateRoundingFactor<T>(

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021-2023 by XGBoost contributors
* Copyright 2021-2024, XGBoost contributors
*
* \brief Implementation for the approx tree method.
*/
@ -107,7 +107,10 @@ class GloablApproxBuilder {
for (auto const &g : gpair) {
root_sum.Add(g);
}
collective::GlobalSum(p_fmat->Info(), reinterpret_cast<double *>(&root_sum), 2);
auto rc = collective::GlobalSum(ctx_, p_fmat->Info(),
linalg::MakeVec(reinterpret_cast<double *>(&root_sum), 2));
collective::SafeColl(rc);
std::vector<CPUExpandEntry> nodes{best};
this->histogram_builder_.BuildRootHist(p_fmat, p_tree, partitioner_,
linalg::MakeTensorView(ctx_, gpair, gpair.size(), 1),

View File

@ -1,5 +1,5 @@
/**
* Copyright 2017-2023 by XGBoost contributors
* Copyright 2017-2024, XGBoost contributors
*/
#include <thrust/copy.h>
#include <thrust/reduce.h>
@ -729,7 +729,9 @@ struct GPUHistMakerDevice {
dh::Reduce(ctx_->CUDACtx()->CTP(), gpair_it, gpair_it + gpair.size(),
GradientPairInt64{}, thrust::plus<GradientPairInt64>{});
using ReduceT = typename decltype(root_sum_quantised)::ValueT;
collective::GlobalSum(info_, reinterpret_cast<ReduceT*>(&root_sum_quantised), 2);
auto rc = collective::GlobalSum(
ctx_, info_, linalg::MakeVec(reinterpret_cast<ReduceT*>(&root_sum_quantised), 2));
collective::SafeColl(rc);
hist.AllocateHistograms({kRootNIdx});
this->BuildHist(kRootNIdx);

View File

@ -199,8 +199,10 @@ class MultiTargetHistBuilder {
}
}
CHECK(root_sum.CContiguous());
collective::GlobalSum(p_fmat->Info(), reinterpret_cast<double *>(root_sum.Values().data()),
root_sum.Size() * 2);
auto rc = collective::GlobalSum(
ctx_, p_fmat->Info(),
linalg::MakeVec(reinterpret_cast<double *>(root_sum.Values().data()), root_sum.Size() * 2));
collective::SafeColl(rc);
histogram_builder_->BuildRootHist(p_fmat, p_tree, partitioner_, gpair, best, HistBatch(param_));
@ -408,7 +410,9 @@ class HistUpdater {
for (auto const &grad : gpair_h) {
grad_stat.Add(grad.GetGrad(), grad.GetHess());
}
collective::GlobalSum(p_fmat->Info(), reinterpret_cast<double *>(&grad_stat), 2);
auto rc = collective::GlobalSum(ctx_, p_fmat->Info(),
linalg::MakeVec(reinterpret_cast<double *>(&grad_stat), 2));
collective::SafeColl(rc);
}
auto weight = evaluator_->InitRoot(GradStats{grad_stat});
@ -471,6 +475,7 @@ class QuantileHistMaker : public TreeUpdater {
std::unique_ptr<HistUpdater> p_impl_{nullptr};
std::unique_ptr<MultiTargetHistBuilder> p_mtimpl_{nullptr};
std::shared_ptr<common::ColumnSampler> column_sampler_;
common::Monitor monitor_;
ObjInfo const *task_{nullptr};
HistMakerTrainParam hist_param_;