More collective aggregators (#9060)

This commit is contained in:
Rong Ou
2023-04-21 12:32:05 -07:00
committed by GitHub
parent 7032981350
commit 8dbe0510de
11 changed files with 107 additions and 89 deletions

View File

@@ -8,6 +8,7 @@
#include <cinttypes> // std::int32_t
#include <cstddef> // std::size_t
#include "../collective/aggregator.h"
#include "../collective/communicator-inl.h"
#include "../common/common.h" // AssertGPUSupport
#include "../common/numeric.h" // cpu_impl::Reduce
@@ -45,10 +46,7 @@ void FitStump(Context const* ctx, MetaInfo const& info,
}
CHECK(h_sum.CContiguous());
if (info.IsRowSplit()) {
collective::Allreduce<collective::Operation::kSum>(
reinterpret_cast<double*>(h_sum.Values().data()), h_sum.Size() * 2);
}
collective::GlobalSum(info, reinterpret_cast<double*>(h_sum.Values().data()), h_sum.Size() * 2);
for (std::size_t i = 0; i < h_sum.Size(); ++i) {
out(i) = static_cast<float>(CalcUnregularizedWeight(h_sum(i).GetGrad(), h_sum(i).GetHess()));

View File

@@ -7,6 +7,7 @@
#include <memory>
#include <vector>
#include "../collective/aggregator.h"
#include "../common/random.h"
#include "../data/gradient_index.h"
#include "common_row_partitioner.h"
@@ -92,9 +93,7 @@ class GloablApproxBuilder {
for (auto const &g : gpair) {
root_sum.Add(g);
}
if (p_fmat->Info().IsRowSplit()) {
collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<double *>(&root_sum), 2);
}
collective::GlobalSum(p_fmat->Info(), reinterpret_cast<double *>(&root_sum), 2);
std::vector<CPUExpandEntry> nodes{best};
size_t i = 0;
auto space = ConstructHistSpace(partitioner_, nodes);