Refactor device communicator to make allreduce more flexible (#9295)
@@ -11,7 +11,7 @@
 #include <tuple>
 #include <utility>
 
-#include "../collective/device_communicator.cuh"
+#include "../collective/communicator-inl.cuh"
 #include "../common/algorithm.cuh"  // SegmentedArgSort
 #include "../common/optional_weight.h"  // OptionalWeights
 #include "../common/threading_utils.cuh"  // UnravelTrapeziodIdx,SegmentedTrapezoidThreads
@@ -205,8 +205,7 @@ double ScaleClasses(common::Span<double> results, common::Span<double> local_are
   if (collective::IsDistributed()) {
     int32_t device = dh::CurrentDevice();
     CHECK_EQ(dh::CudaGetPointerDevice(results.data()), device);
-    auto* communicator = collective::Communicator::GetDevice(device);
-    communicator->AllReduceSum(results.data(), results.size());
+    collective::AllReduce<collective::Operation::kSum>(device, results.data(), results.size());
   }
   auto reduce_in = dh::MakeTransformIterator<Pair>(
       thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) {
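
For context, the refactor replaces the per-device DeviceCommunicator object with a free function template, so the reduction operation is chosen at the call site instead of through a dedicated method. Below is a minimal sketch of the new calling convention, based only on the lines in this diff; the wrapper name ReduceSums is hypothetical.

    #include <cstddef>  // std::size_t
    #include <cstdint>  // std::int32_t

    #include "../collective/communicator-inl.cuh"  // collective::AllReduce

    // Hypothetical helper illustrating the refactored API.
    void ReduceSums(std::int32_t device, double* sums, std::size_t n) {
      if (collective::IsDistributed()) {
        // Old (removed) pattern: fetch a per-device communicator, then call
        // an operation-specific method:
        //   auto* communicator = collective::Communicator::GetDevice(device);
        //   communicator->AllReduceSum(sums, n);
        // New pattern: a single templated entry point; the reduction op is a
        // template parameter, so other operations reuse the same call.
        collective::AllReduce<collective::Operation::kSum>(device, sums, n);
      }
    }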