[Breaking] Switch from rabit to the collective communicator (#8257)
* Switch from rabit to the collective communicator * fix size_t specialization * really fix size_t * try again * add include * more include * fix lint errors * remove rabit includes * fix pylint error * return dict from communicator context * fix communicator shutdown * fix dask test * reset communicator mocklist * fix distributed tests * do not save device communicator * fix jvm gpu tests * add python test for federated communicator * Update gputreeshap submodule Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -23,17 +23,28 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
|
||||
|
||||
~DeviceCommunicatorAdapter() override = default;
|
||||
|
||||
void AllReduceSum(double *send_receive_buffer, int count) override {
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
auto size = count * sizeof(double);
|
||||
host_buffer_.reserve(size);
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
|
||||
communicator_->AllReduce(host_buffer_.data(), count, DataType::kDouble, Operation::kSum);
|
||||
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
|
||||
void AllReduceSum(float *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<collective::DataType::kFloat>(send_receive_buffer, count);
|
||||
}
|
||||
|
||||
void AllReduceSum(double *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<collective::DataType::kDouble>(send_receive_buffer, count);
|
||||
}
|
||||
|
||||
void AllReduceSum(int64_t *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<collective::DataType::kInt64>(send_receive_buffer, count);
|
||||
}
|
||||
|
||||
void AllReduceSum(uint64_t *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<collective::DataType::kUInt64>(send_receive_buffer, count);
|
||||
}
|
||||
|
||||
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
|
||||
dh::caching_device_vector<char> *receive_buffer) override {
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
int const world_size = communicator_->GetWorldSize();
|
||||
int const rank = communicator_->GetRank();
|
||||
@@ -66,6 +77,20 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
|
||||
}
|
||||
|
||||
private:
|
||||
template <collective::DataType data_type, typename T>
|
||||
void DoAllReduceSum(T *send_receive_buffer, size_t count) {
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
auto size = count * sizeof(T);
|
||||
host_buffer_.reserve(size);
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
|
||||
communicator_->AllReduce(host_buffer_.data(), count, data_type, collective::Operation::kSum);
|
||||
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
int const device_ordinal_;
|
||||
Communicator *communicator_;
|
||||
/// Host buffer used to call communicator functions.
|
||||
|
||||
Reference in New Issue
Block a user