Require context in aggregators. (#10075)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
/**
|
||||
* Copyright 2022 by XGBoost Contributors
|
||||
* Copyright 2022-2024, XGBoost Contributors
|
||||
*
|
||||
* \brief Utilities for estimating initial score.
|
||||
* @brief Utilities for estimating initial score.
|
||||
*/
|
||||
#include "fit_stump.h"
|
||||
|
||||
@@ -44,8 +44,11 @@ void FitStump(Context const* ctx, MetaInfo const& info,
|
||||
}
|
||||
}
|
||||
CHECK(h_sum.CContiguous());
|
||||
|
||||
collective::GlobalSum(info, reinterpret_cast<double*>(h_sum.Values().data()), h_sum.Size() * 2);
|
||||
auto as_double = linalg::MakeTensorView(
|
||||
ctx, common::Span{reinterpret_cast<double*>(h_sum.Values().data()), h_sum.Size() * 2},
|
||||
h_sum.Size() * 2);
|
||||
auto rc = collective::GlobalSum(ctx, info, as_double);
|
||||
collective::SafeColl(rc);
|
||||
|
||||
for (std::size_t i = 0; i < h_sum.Size(); ++i) {
|
||||
out(i) = static_cast<float>(CalcUnregularizedWeight(h_sum(i).GetGrad(), h_sum(i).GetHess()));
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
/**
|
||||
* Copyright 2022-2023 by XGBoost Contributors
|
||||
* Copyright 2022-2024, XGBoost Contributors
|
||||
*
|
||||
* \brief Utilities for estimating initial score.
|
||||
* @brief Utilities for estimating initial score.
|
||||
*/
|
||||
#if !defined(NOMINMAX) && defined(_WIN32)
|
||||
#define NOMINMAX
|
||||
#endif // !defined(NOMINMAX)
|
||||
#include <thrust/execution_policy.h> // cuda::par
|
||||
#include <thrust/iterator/counting_iterator.h> // thrust::make_counting_iterator
|
||||
#endif // !defined(NOMINMAX)
|
||||
#include <thrust/execution_policy.h> // cuda::par
|
||||
#include <thrust/iterator/counting_iterator.h> // thrust::make_counting_iterator
|
||||
|
||||
#include <cstddef> // std::size_t
|
||||
#include <cstddef> // std::size_t
|
||||
|
||||
#include "../collective/aggregator.cuh"
|
||||
#include "../collective/communicator-inl.cuh"
|
||||
#include "../common/device_helpers.cuh" // dh::MakeTransformIterator
|
||||
#include "../collective/aggregator.cuh" // for GlobalSum
|
||||
#include "../common/device_helpers.cuh" // dh::MakeTransformIterator
|
||||
#include "fit_stump.h"
|
||||
#include "xgboost/base.h" // GradientPairPrecise, GradientPair, XGBOOST_DEVICE
|
||||
#include "xgboost/context.h" // Context
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2020-2023 by XGBoost Contributors
|
||||
* Copyright 2020-2024, XGBoost Contributors
|
||||
*/
|
||||
#include <thrust/iterator/transform_iterator.h>
|
||||
#include <thrust/reduce.h>
|
||||
@@ -52,7 +52,7 @@ struct Clip : public thrust::unary_function<GradientPair, Pair> {
|
||||
*
|
||||
* to avoid outliers, as the full reduction is reproducible on GPU with reduction tree.
|
||||
*/
|
||||
GradientQuantiser::GradientQuantiser(Context const*, common::Span<GradientPair const> gpair,
|
||||
GradientQuantiser::GradientQuantiser(Context const* ctx, common::Span<GradientPair const> gpair,
|
||||
MetaInfo const& info) {
|
||||
using GradientSumT = GradientPairPrecise;
|
||||
using T = typename GradientSumT::ValueT;
|
||||
@@ -65,11 +65,14 @@ GradientQuantiser::GradientQuantiser(Context const*, common::Span<GradientPair c
|
||||
// Treat pair as array of 4 primitive types to allreduce
|
||||
using ReduceT = typename decltype(p.first)::ValueT;
|
||||
static_assert(sizeof(Pair) == sizeof(ReduceT) * 4, "Expected to reduce four elements.");
|
||||
collective::GlobalSum(info, reinterpret_cast<ReduceT*>(&p), 4);
|
||||
auto rc = collective::GlobalSum(ctx, info, linalg::MakeVec(reinterpret_cast<ReduceT*>(&p), 4));
|
||||
collective::SafeColl(rc);
|
||||
|
||||
GradientPair positive_sum{p.first}, negative_sum{p.second};
|
||||
|
||||
std::size_t total_rows = gpair.size();
|
||||
collective::GlobalSum(info, &total_rows, 1);
|
||||
rc = collective::GlobalSum(ctx, info, linalg::MakeVec(&total_rows, 1));
|
||||
collective::SafeColl(rc);
|
||||
|
||||
auto histogram_rounding =
|
||||
GradientSumT{common::CreateRoundingFactor<T>(
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2021-2023 by XGBoost contributors
|
||||
* Copyright 2021-2024, XGBoost contributors
|
||||
*
|
||||
* \brief Implementation for the approx tree method.
|
||||
*/
|
||||
@@ -107,7 +107,10 @@ class GloablApproxBuilder {
|
||||
for (auto const &g : gpair) {
|
||||
root_sum.Add(g);
|
||||
}
|
||||
collective::GlobalSum(p_fmat->Info(), reinterpret_cast<double *>(&root_sum), 2);
|
||||
auto rc = collective::GlobalSum(ctx_, p_fmat->Info(),
|
||||
linalg::MakeVec(reinterpret_cast<double *>(&root_sum), 2));
|
||||
collective::SafeColl(rc);
|
||||
|
||||
std::vector<CPUExpandEntry> nodes{best};
|
||||
this->histogram_builder_.BuildRootHist(p_fmat, p_tree, partitioner_,
|
||||
linalg::MakeTensorView(ctx_, gpair, gpair.size(), 1),
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2017-2023 by XGBoost contributors
|
||||
* Copyright 2017-2024, XGBoost contributors
|
||||
*/
|
||||
#include <thrust/copy.h>
|
||||
#include <thrust/reduce.h>
|
||||
@@ -729,7 +729,9 @@ struct GPUHistMakerDevice {
|
||||
dh::Reduce(ctx_->CUDACtx()->CTP(), gpair_it, gpair_it + gpair.size(),
|
||||
GradientPairInt64{}, thrust::plus<GradientPairInt64>{});
|
||||
using ReduceT = typename decltype(root_sum_quantised)::ValueT;
|
||||
collective::GlobalSum(info_, reinterpret_cast<ReduceT*>(&root_sum_quantised), 2);
|
||||
auto rc = collective::GlobalSum(
|
||||
ctx_, info_, linalg::MakeVec(reinterpret_cast<ReduceT*>(&root_sum_quantised), 2));
|
||||
collective::SafeColl(rc);
|
||||
|
||||
hist.AllocateHistograms({kRootNIdx});
|
||||
this->BuildHist(kRootNIdx);
|
||||
|
||||
@@ -199,8 +199,10 @@ class MultiTargetHistBuilder {
|
||||
}
|
||||
}
|
||||
CHECK(root_sum.CContiguous());
|
||||
collective::GlobalSum(p_fmat->Info(), reinterpret_cast<double *>(root_sum.Values().data()),
|
||||
root_sum.Size() * 2);
|
||||
auto rc = collective::GlobalSum(
|
||||
ctx_, p_fmat->Info(),
|
||||
linalg::MakeVec(reinterpret_cast<double *>(root_sum.Values().data()), root_sum.Size() * 2));
|
||||
collective::SafeColl(rc);
|
||||
|
||||
histogram_builder_->BuildRootHist(p_fmat, p_tree, partitioner_, gpair, best, HistBatch(param_));
|
||||
|
||||
@@ -408,7 +410,9 @@ class HistUpdater {
|
||||
for (auto const &grad : gpair_h) {
|
||||
grad_stat.Add(grad.GetGrad(), grad.GetHess());
|
||||
}
|
||||
collective::GlobalSum(p_fmat->Info(), reinterpret_cast<double *>(&grad_stat), 2);
|
||||
auto rc = collective::GlobalSum(ctx_, p_fmat->Info(),
|
||||
linalg::MakeVec(reinterpret_cast<double *>(&grad_stat), 2));
|
||||
collective::SafeColl(rc);
|
||||
}
|
||||
|
||||
auto weight = evaluator_->InitRoot(GradStats{grad_stat});
|
||||
@@ -471,6 +475,7 @@ class QuantileHistMaker : public TreeUpdater {
|
||||
std::unique_ptr<HistUpdater> p_impl_{nullptr};
|
||||
std::unique_ptr<MultiTargetHistBuilder> p_mtimpl_{nullptr};
|
||||
std::shared_ptr<common::ColumnSampler> column_sampler_;
|
||||
|
||||
common::Monitor monitor_;
|
||||
ObjInfo const *task_{nullptr};
|
||||
HistMakerTrainParam hist_param_;
|
||||
|
||||
Reference in New Issue
Block a user