Require context in aggregators. (#10075)

This commit is contained in:
Jiaming Yuan
2024-02-28 03:12:42 +08:00
committed by GitHub
parent 761845f594
commit 5ac233280e
23 changed files with 190 additions and 144 deletions

View File

@@ -1,7 +1,7 @@
/**
* Copyright 2022 by XGBoost Contributors
* Copyright 2022-2024, XGBoost Contributors
*
* \brief Utilities for estimating initial score.
* @brief Utilities for estimating initial score.
*/
#include "fit_stump.h"
@@ -44,8 +44,11 @@ void FitStump(Context const* ctx, MetaInfo const& info,
}
}
CHECK(h_sum.CContiguous());
collective::GlobalSum(info, reinterpret_cast<double*>(h_sum.Values().data()), h_sum.Size() * 2);
auto as_double = linalg::MakeTensorView(
ctx, common::Span{reinterpret_cast<double*>(h_sum.Values().data()), h_sum.Size() * 2},
h_sum.Size() * 2);
auto rc = collective::GlobalSum(ctx, info, as_double);
collective::SafeColl(rc);
for (std::size_t i = 0; i < h_sum.Size(); ++i) {
out(i) = static_cast<float>(CalcUnregularizedWeight(h_sum(i).GetGrad(), h_sum(i).GetHess()));

View File

@@ -1,19 +1,18 @@
/**
* Copyright 2022-2023 by XGBoost Contributors
* Copyright 2022-2024, XGBoost Contributors
*
* \brief Utilities for estimating initial score.
* @brief Utilities for estimating initial score.
*/
#if !defined(NOMINMAX) && defined(_WIN32)
#define NOMINMAX
#endif // !defined(NOMINMAX)
#include <thrust/execution_policy.h> // cuda::par
#include <thrust/iterator/counting_iterator.h> // thrust::make_counting_iterator
#endif // !defined(NOMINMAX)
#include <thrust/execution_policy.h> // cuda::par
#include <thrust/iterator/counting_iterator.h> // thrust::make_counting_iterator
#include <cstddef> // std::size_t
#include <cstddef> // std::size_t
#include "../collective/aggregator.cuh"
#include "../collective/communicator-inl.cuh"
#include "../common/device_helpers.cuh" // dh::MakeTransformIterator
#include "../collective/aggregator.cuh" // for GlobalSum
#include "../common/device_helpers.cuh" // dh::MakeTransformIterator
#include "fit_stump.h"
#include "xgboost/base.h" // GradientPairPrecise, GradientPair, XGBOOST_DEVICE
#include "xgboost/context.h" // Context

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020-2023 by XGBoost Contributors
* Copyright 2020-2024, XGBoost Contributors
*/
#include <thrust/iterator/transform_iterator.h>
#include <thrust/reduce.h>
@@ -52,7 +52,7 @@ struct Clip : public thrust::unary_function<GradientPair, Pair> {
*
* to avoid outliers, as the full reduction is reproducible on GPU with reduction tree.
*/
GradientQuantiser::GradientQuantiser(Context const*, common::Span<GradientPair const> gpair,
GradientQuantiser::GradientQuantiser(Context const* ctx, common::Span<GradientPair const> gpair,
MetaInfo const& info) {
using GradientSumT = GradientPairPrecise;
using T = typename GradientSumT::ValueT;
@@ -65,11 +65,14 @@ GradientQuantiser::GradientQuantiser(Context const*, common::Span<GradientPair c
// Treat pair as array of 4 primitive types to allreduce
using ReduceT = typename decltype(p.first)::ValueT;
static_assert(sizeof(Pair) == sizeof(ReduceT) * 4, "Expected to reduce four elements.");
collective::GlobalSum(info, reinterpret_cast<ReduceT*>(&p), 4);
auto rc = collective::GlobalSum(ctx, info, linalg::MakeVec(reinterpret_cast<ReduceT*>(&p), 4));
collective::SafeColl(rc);
GradientPair positive_sum{p.first}, negative_sum{p.second};
std::size_t total_rows = gpair.size();
collective::GlobalSum(info, &total_rows, 1);
rc = collective::GlobalSum(ctx, info, linalg::MakeVec(&total_rows, 1));
collective::SafeColl(rc);
auto histogram_rounding =
GradientSumT{common::CreateRoundingFactor<T>(

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2021-2023 by XGBoost contributors
* Copyright 2021-2024, XGBoost contributors
*
* \brief Implementation for the approx tree method.
*/
@@ -107,7 +107,10 @@ class GloablApproxBuilder {
for (auto const &g : gpair) {
root_sum.Add(g);
}
collective::GlobalSum(p_fmat->Info(), reinterpret_cast<double *>(&root_sum), 2);
auto rc = collective::GlobalSum(ctx_, p_fmat->Info(),
linalg::MakeVec(reinterpret_cast<double *>(&root_sum), 2));
collective::SafeColl(rc);
std::vector<CPUExpandEntry> nodes{best};
this->histogram_builder_.BuildRootHist(p_fmat, p_tree, partitioner_,
linalg::MakeTensorView(ctx_, gpair, gpair.size(), 1),

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2017-2023 by XGBoost contributors
* Copyright 2017-2024, XGBoost contributors
*/
#include <thrust/copy.h>
#include <thrust/reduce.h>
@@ -729,7 +729,9 @@ struct GPUHistMakerDevice {
dh::Reduce(ctx_->CUDACtx()->CTP(), gpair_it, gpair_it + gpair.size(),
GradientPairInt64{}, thrust::plus<GradientPairInt64>{});
using ReduceT = typename decltype(root_sum_quantised)::ValueT;
collective::GlobalSum(info_, reinterpret_cast<ReduceT*>(&root_sum_quantised), 2);
auto rc = collective::GlobalSum(
ctx_, info_, linalg::MakeVec(reinterpret_cast<ReduceT*>(&root_sum_quantised), 2));
collective::SafeColl(rc);
hist.AllocateHistograms({kRootNIdx});
this->BuildHist(kRootNIdx);

View File

@@ -199,8 +199,10 @@ class MultiTargetHistBuilder {
}
}
CHECK(root_sum.CContiguous());
collective::GlobalSum(p_fmat->Info(), reinterpret_cast<double *>(root_sum.Values().data()),
root_sum.Size() * 2);
auto rc = collective::GlobalSum(
ctx_, p_fmat->Info(),
linalg::MakeVec(reinterpret_cast<double *>(root_sum.Values().data()), root_sum.Size() * 2));
collective::SafeColl(rc);
histogram_builder_->BuildRootHist(p_fmat, p_tree, partitioner_, gpair, best, HistBatch(param_));
@@ -408,7 +410,9 @@ class HistUpdater {
for (auto const &grad : gpair_h) {
grad_stat.Add(grad.GetGrad(), grad.GetHess());
}
collective::GlobalSum(p_fmat->Info(), reinterpret_cast<double *>(&grad_stat), 2);
auto rc = collective::GlobalSum(ctx_, p_fmat->Info(),
linalg::MakeVec(reinterpret_cast<double *>(&grad_stat), 2));
collective::SafeColl(rc);
}
auto weight = evaluator_->InitRoot(GradStats{grad_stat});
@@ -471,6 +475,7 @@ class QuantileHistMaker : public TreeUpdater {
std::unique_ptr<HistUpdater> p_impl_{nullptr};
std::unique_ptr<MultiTargetHistBuilder> p_mtimpl_{nullptr};
std::shared_ptr<common::ColumnSampler> column_sampler_;
common::Monitor monitor_;
ObjInfo const *task_{nullptr};
HistMakerTrainParam hist_param_;