Refactor device communicator to make allreduce more flexible (#9295)
This commit is contained in:
@@ -12,8 +12,7 @@
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "../collective/communicator.h"
|
||||
#include "../collective/device_communicator.cuh"
|
||||
#include "../collective/communicator-inl.cuh"
|
||||
#include "categorical.h"
|
||||
#include "common.h"
|
||||
#include "device_helpers.cuh"
|
||||
@@ -510,7 +509,6 @@ void SketchContainer::AllReduce() {
|
||||
}
|
||||
|
||||
timer_.Start(__func__);
|
||||
auto* communicator = collective::Communicator::GetDevice(device_);
|
||||
// Reduce the overhead on syncing.
|
||||
size_t global_sum_rows = num_rows_;
|
||||
collective::Allreduce<collective::Operation::kSum>(&global_sum_rows, 1);
|
||||
@@ -531,14 +529,15 @@ void SketchContainer::AllReduce() {
|
||||
auto offset = rank * d_columns_ptr.size();
|
||||
thrust::copy(thrust::device, d_columns_ptr.data(), d_columns_ptr.data() + d_columns_ptr.size(),
|
||||
gathered_ptrs.begin() + offset);
|
||||
communicator->AllReduceSum(gathered_ptrs.data().get(), gathered_ptrs.size());
|
||||
collective::AllReduce<collective::Operation::kSum>(device_, gathered_ptrs.data().get(),
|
||||
gathered_ptrs.size());
|
||||
|
||||
// Get the data from all workers.
|
||||
std::vector<size_t> recv_lengths;
|
||||
dh::caching_device_vector<char> recvbuf;
|
||||
communicator->AllGatherV(this->Current().data().get(), dh::ToSpan(this->Current()).size_bytes(),
|
||||
&recv_lengths, &recvbuf);
|
||||
communicator->Synchronize();
|
||||
collective::AllGatherV(device_, this->Current().data().get(),
|
||||
dh::ToSpan(this->Current()).size_bytes(), &recv_lengths, &recvbuf);
|
||||
collective::Synchronize(device_);
|
||||
|
||||
// Segment the received data.
|
||||
auto s_recvbuf = dh::ToSpan(recvbuf);
|
||||
|
||||
Reference in New Issue
Block a user