Refactor device communicator to make allreduce more flexible (#9295)

This commit is contained in:
Rong Ou
2023-06-13 12:53:03 -07:00
committed by GitHub
parent c2f0486d37
commit e70810be8a
11 changed files with 190 additions and 106 deletions

View File

@@ -11,7 +11,7 @@
#include <tuple>
#include <utility>
#include "../collective/device_communicator.cuh"
#include "../collective/communicator-inl.cuh"
#include "../common/algorithm.cuh" // SegmentedArgSort
#include "../common/optional_weight.h" // OptionalWeights
#include "../common/threading_utils.cuh" // UnravelTrapeziodIdx,SegmentedTrapezoidThreads
@@ -205,8 +205,7 @@ double ScaleClasses(common::Span<double> results, common::Span<double> local_are
if (collective::IsDistributed()) {
int32_t device = dh::CurrentDevice();
CHECK_EQ(dh::CudaGetPointerDevice(results.data()), device);
auto* communicator = collective::Communicator::GetDevice(device);
communicator->AllReduceSum(results.data(), results.size());
collective::AllReduce<collective::Operation::kSum>(device, results.data(), results.size());
}
auto reduce_in = dh::MakeTransformIterator<Pair>(
thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) {