Fix NCCL test hang (#9367)
This commit is contained in:
parent
41c6813496
commit
15ca12a77e
@ -29,10 +29,18 @@ DeviceCommunicator* Communicator::GetDevice(int device_ordinal) {
|
|||||||
old_device_ordinal = device_ordinal;
|
old_device_ordinal = device_ordinal;
|
||||||
old_world_size = communicator_->GetWorldSize();
|
old_world_size = communicator_->GetWorldSize();
|
||||||
#ifdef XGBOOST_USE_NCCL
|
#ifdef XGBOOST_USE_NCCL
|
||||||
if (type_ != CommunicatorType::kFederated) {
|
switch (type_) {
|
||||||
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal));
|
case CommunicatorType::kRabit:
|
||||||
} else {
|
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false));
|
||||||
|
break;
|
||||||
|
case CommunicatorType::kFederated:
|
||||||
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
|
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
|
||||||
|
break;
|
||||||
|
case CommunicatorType::kInMemory:
|
||||||
|
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, true));
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false));
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
|
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
|
||||||
|
|||||||
@ -7,8 +7,11 @@
|
|||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace collective {
|
namespace collective {
|
||||||
|
|
||||||
NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal)
|
NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sync)
|
||||||
: device_ordinal_{device_ordinal}, world_size_{GetWorldSize()}, rank_{GetRank()} {
|
: device_ordinal_{device_ordinal},
|
||||||
|
needs_sync_{needs_sync},
|
||||||
|
world_size_{GetWorldSize()},
|
||||||
|
rank_{GetRank()} {
|
||||||
if (device_ordinal_ < 0) {
|
if (device_ordinal_ < 0) {
|
||||||
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
|
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
|
||||||
}
|
}
|
||||||
@ -140,6 +143,9 @@ void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::si
|
|||||||
// First gather data from all the workers.
|
// First gather data from all the workers.
|
||||||
dh::safe_nccl(ncclAllGather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
|
dh::safe_nccl(ncclAllGather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
|
||||||
nccl_comm_, cuda_stream_));
|
nccl_comm_, cuda_stream_));
|
||||||
|
if (needs_sync_) {
|
||||||
|
dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
|
||||||
|
}
|
||||||
|
|
||||||
// Then reduce locally.
|
// Then reduce locally.
|
||||||
auto *out_buffer = static_cast<char *>(send_receive_buffer);
|
auto *out_buffer = static_cast<char *>(send_receive_buffer);
|
||||||
|
|||||||
@ -12,7 +12,20 @@ namespace collective {
|
|||||||
|
|
||||||
class NcclDeviceCommunicator : public DeviceCommunicator {
|
class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||||
public:
|
public:
|
||||||
explicit NcclDeviceCommunicator(int device_ordinal);
|
/**
|
||||||
|
* @brief Construct a new NCCL communicator.
|
||||||
|
* @param device_ordinal The GPU device id.
|
||||||
|
* @param needs_sync Whether extra CUDA stream synchronization is needed.
|
||||||
|
*
|
||||||
|
* In multi-GPU tests when multiple NCCL communicators are created in the same process, sometimes
|
||||||
|
* a deadlock happens because NCCL kernels are blocking. The extra CUDA stream synchronization
|
||||||
|
* makes sure that the NCCL kernels are caught up, thus avoiding the deadlock.
|
||||||
|
*
|
||||||
|
* The Rabit communicator runs with one process per GPU, so the additional synchronization is not
|
||||||
|
* needed. The in-memory communicator is used in tests with multiple threads, each thread
|
||||||
|
* representing a rank/worker, so the additional synchronization is needed to avoid deadlocks.
|
||||||
|
*/
|
||||||
|
explicit NcclDeviceCommunicator(int device_ordinal, bool needs_sync);
|
||||||
~NcclDeviceCommunicator() override;
|
~NcclDeviceCommunicator() override;
|
||||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||||
Operation op) override;
|
Operation op) override;
|
||||||
@ -60,6 +73,7 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
|||||||
Operation op);
|
Operation op);
|
||||||
|
|
||||||
int const device_ordinal_;
|
int const device_ordinal_;
|
||||||
|
bool const needs_sync_;
|
||||||
int const world_size_;
|
int const world_size_;
|
||||||
int const rank_;
|
int const rank_;
|
||||||
ncclComm_t nccl_comm_{};
|
ncclComm_t nccl_comm_{};
|
||||||
|
|||||||
@ -16,7 +16,7 @@ namespace xgboost {
|
|||||||
namespace collective {
|
namespace collective {
|
||||||
|
|
||||||
TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidDeviceOrdinal) {
|
TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidDeviceOrdinal) {
|
||||||
auto construct = []() { NcclDeviceCommunicator comm{-1}; };
|
auto construct = []() { NcclDeviceCommunicator comm{-1, false}; };
|
||||||
EXPECT_THROW(construct(), dmlc::Error);
|
EXPECT_THROW(construct(), dmlc::Error);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user