Fix NCCL test hang (#9367)
This commit is contained in:
parent
41c6813496
commit
15ca12a77e
@ -29,10 +29,18 @@ DeviceCommunicator* Communicator::GetDevice(int device_ordinal) {
|
||||
old_device_ordinal = device_ordinal;
|
||||
old_world_size = communicator_->GetWorldSize();
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
if (type_ != CommunicatorType::kFederated) {
|
||||
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal));
|
||||
} else {
|
||||
switch (type_) {
|
||||
case CommunicatorType::kRabit:
|
||||
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false));
|
||||
break;
|
||||
case CommunicatorType::kFederated:
|
||||
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
|
||||
break;
|
||||
case CommunicatorType::kInMemory:
|
||||
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, true));
|
||||
break;
|
||||
default:
|
||||
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false));
|
||||
}
|
||||
#else
|
||||
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
|
||||
|
||||
@ -7,8 +7,11 @@
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal)
|
||||
: device_ordinal_{device_ordinal}, world_size_{GetWorldSize()}, rank_{GetRank()} {
|
||||
NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sync)
|
||||
: device_ordinal_{device_ordinal},
|
||||
needs_sync_{needs_sync},
|
||||
world_size_{GetWorldSize()},
|
||||
rank_{GetRank()} {
|
||||
if (device_ordinal_ < 0) {
|
||||
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
|
||||
}
|
||||
@ -140,6 +143,9 @@ void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::si
|
||||
// First gather data from all the workers.
|
||||
dh::safe_nccl(ncclAllGather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
|
||||
nccl_comm_, cuda_stream_));
|
||||
if (needs_sync_) {
|
||||
dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
|
||||
}
|
||||
|
||||
// Then reduce locally.
|
||||
auto *out_buffer = static_cast<char *>(send_receive_buffer);
|
||||
|
||||
@ -12,7 +12,20 @@ namespace collective {
|
||||
|
||||
class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
public:
|
||||
explicit NcclDeviceCommunicator(int device_ordinal);
|
||||
/**
|
||||
* @brief Construct a new NCCL communicator.
|
||||
* @param device_ordinal The GPU device id.
|
||||
* @param needs_sync Whether extra CUDA stream synchronization is needed.
|
||||
*
|
||||
* In multi-GPU tests when multiple NCCL communicators are created in the same process, sometimes
|
||||
* a deadlock happens because NCCL kernels are blocking. The extra CUDA stream synchronization
|
||||
* makes sure that the NCCL kernels are caught up, thus avoiding the deadlock.
|
||||
*
|
||||
* The Rabit communicator runs with one process per GPU, so the additional synchronization is not
|
||||
* needed. The in-memory communicator is used in tests with multiple threads, each thread
|
||||
* representing a rank/worker, so the additional synchronization is needed to avoid deadlocks.
|
||||
*/
|
||||
explicit NcclDeviceCommunicator(int device_ordinal, bool needs_sync);
|
||||
~NcclDeviceCommunicator() override;
|
||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) override;
|
||||
@ -60,6 +73,7 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
Operation op);
|
||||
|
||||
int const device_ordinal_;
|
||||
bool const needs_sync_;
|
||||
int const world_size_;
|
||||
int const rank_;
|
||||
ncclComm_t nccl_comm_{};
|
||||
|
||||
@ -16,7 +16,7 @@ namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidDeviceOrdinal) {
|
||||
auto construct = []() { NcclDeviceCommunicator comm{-1}; };
|
||||
auto construct = []() { NcclDeviceCommunicator comm{-1, false}; };
|
||||
EXPECT_THROW(construct(), dmlc::Error);
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user