More collective aggregators (#9060)
This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
#include <cinttypes> // std::int32_t
|
||||
#include <cstddef> // std::size_t
|
||||
|
||||
#include "../collective/aggregator.h"
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/common.h" // AssertGPUSupport
|
||||
#include "../common/numeric.h" // cpu_impl::Reduce
|
||||
@@ -45,10 +46,7 @@ void FitStump(Context const* ctx, MetaInfo const& info,
|
||||
}
|
||||
CHECK(h_sum.CContiguous());
|
||||
|
||||
if (info.IsRowSplit()) {
|
||||
collective::Allreduce<collective::Operation::kSum>(
|
||||
reinterpret_cast<double*>(h_sum.Values().data()), h_sum.Size() * 2);
|
||||
}
|
||||
collective::GlobalSum(info, reinterpret_cast<double*>(h_sum.Values().data()), h_sum.Size() * 2);
|
||||
|
||||
for (std::size_t i = 0; i < h_sum.Size(); ++i) {
|
||||
out(i) = static_cast<float>(CalcUnregularizedWeight(h_sum(i).GetGrad(), h_sum(i).GetHess()));
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "../collective/aggregator.h"
|
||||
#include "../common/random.h"
|
||||
#include "../data/gradient_index.h"
|
||||
#include "common_row_partitioner.h"
|
||||
@@ -92,9 +93,7 @@ class GloablApproxBuilder {
|
||||
for (auto const &g : gpair) {
|
||||
root_sum.Add(g);
|
||||
}
|
||||
if (p_fmat->Info().IsRowSplit()) {
|
||||
collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<double *>(&root_sum), 2);
|
||||
}
|
||||
collective::GlobalSum(p_fmat->Info(), reinterpret_cast<double *>(&root_sum), 2);
|
||||
std::vector<CPUExpandEntry> nodes{best};
|
||||
size_t i = 0;
|
||||
auto space = ConstructHistSpace(partitioner_, nodes);
|
||||
|
||||
Reference in New Issue
Block a user