Refactor device communicator to make allreduce more flexible (#9295)

This commit is contained in:
Rong Ou
2023-06-13 12:53:03 -07:00
committed by GitHub
parent c2f0486d37
commit e70810be8a
11 changed files with 190 additions and 106 deletions

View File

@@ -12,8 +12,7 @@
#include <memory>
#include <utility>
#include "../collective/communicator.h"
#include "../collective/device_communicator.cuh"
#include "../collective/communicator-inl.cuh"
#include "categorical.h"
#include "common.h"
#include "device_helpers.cuh"
@@ -510,7 +509,6 @@ void SketchContainer::AllReduce() {
}
timer_.Start(__func__);
auto* communicator = collective::Communicator::GetDevice(device_);
// Reduce the overhead on syncing.
size_t global_sum_rows = num_rows_;
collective::Allreduce<collective::Operation::kSum>(&global_sum_rows, 1);
@@ -531,14 +529,15 @@ void SketchContainer::AllReduce() {
auto offset = rank * d_columns_ptr.size();
thrust::copy(thrust::device, d_columns_ptr.data(), d_columns_ptr.data() + d_columns_ptr.size(),
gathered_ptrs.begin() + offset);
communicator->AllReduceSum(gathered_ptrs.data().get(), gathered_ptrs.size());
collective::AllReduce<collective::Operation::kSum>(device_, gathered_ptrs.data().get(),
gathered_ptrs.size());
// Get the data from all workers.
std::vector<size_t> recv_lengths;
dh::caching_device_vector<char> recvbuf;
communicator->AllGatherV(this->Current().data().get(), dh::ToSpan(this->Current()).size_bytes(),
&recv_lengths, &recvbuf);
communicator->Synchronize();
collective::AllGatherV(device_, this->Current().data().get(),
dh::ToSpan(this->Current()).size_bytes(), &recv_lengths, &recvbuf);
collective::Synchronize(device_);
// Segment the received data.
auto s_recvbuf = dh::ToSpan(recvbuf);