Fix NCCL test hang (#9367)

This commit is contained in:
Rong Ou 2023-07-06 20:21:35 -07:00 committed by GitHub
parent 41c6813496
commit 15ca12a77e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 36 additions and 8 deletions

View File

@ -29,10 +29,18 @@ DeviceCommunicator* Communicator::GetDevice(int device_ordinal) {
old_device_ordinal = device_ordinal; old_device_ordinal = device_ordinal;
old_world_size = communicator_->GetWorldSize(); old_world_size = communicator_->GetWorldSize();
#ifdef XGBOOST_USE_NCCL #ifdef XGBOOST_USE_NCCL
if (type_ != CommunicatorType::kFederated) { switch (type_) {
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal)); case CommunicatorType::kRabit:
} else { device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false));
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal)); break;
case CommunicatorType::kFederated:
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
break;
case CommunicatorType::kInMemory:
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, true));
break;
default:
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false));
} }
#else #else
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal)); device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));

View File

@ -7,8 +7,11 @@
namespace xgboost { namespace xgboost {
namespace collective { namespace collective {
NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal) NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sync)
: device_ordinal_{device_ordinal}, world_size_{GetWorldSize()}, rank_{GetRank()} { : device_ordinal_{device_ordinal},
needs_sync_{needs_sync},
world_size_{GetWorldSize()},
rank_{GetRank()} {
if (device_ordinal_ < 0) { if (device_ordinal_ < 0) {
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_; LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
} }
@ -140,6 +143,9 @@ void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::si
// First gather data from all the workers. // First gather data from all the workers.
dh::safe_nccl(ncclAllGather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type), dh::safe_nccl(ncclAllGather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
nccl_comm_, cuda_stream_)); nccl_comm_, cuda_stream_));
if (needs_sync_) {
dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
}
// Then reduce locally. // Then reduce locally.
auto *out_buffer = static_cast<char *>(send_receive_buffer); auto *out_buffer = static_cast<char *>(send_receive_buffer);

View File

@ -12,7 +12,20 @@ namespace collective {
class NcclDeviceCommunicator : public DeviceCommunicator { class NcclDeviceCommunicator : public DeviceCommunicator {
public: public:
explicit NcclDeviceCommunicator(int device_ordinal); /**
* @brief Construct a new NCCL communicator.
* @param device_ordinal The GPU device id.
* @param needs_sync Whether extra CUDA stream synchronization is needed.
*
* In multi-GPU tests when multiple NCCL communicators are created in the same process, sometimes
* a deadlock happens because NCCL kernels are blocking. The extra CUDA stream synchronization
* makes sure that the NCCL kernels are caught up, thus avoiding the deadlock.
*
* The Rabit communicator runs with one process per GPU, so the additional synchronization is not
* needed. The in-memory communicator is used in tests with multiple threads, each thread
* representing a rank/worker, so the additional synchronization is needed to avoid deadlocks.
*/
explicit NcclDeviceCommunicator(int device_ordinal, bool needs_sync);
~NcclDeviceCommunicator() override; ~NcclDeviceCommunicator() override;
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override; Operation op) override;
@ -60,6 +73,7 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
Operation op); Operation op);
int const device_ordinal_; int const device_ordinal_;
bool const needs_sync_;
int const world_size_; int const world_size_;
int const rank_; int const rank_;
ncclComm_t nccl_comm_{}; ncclComm_t nccl_comm_{};

View File

@ -16,7 +16,7 @@ namespace xgboost {
namespace collective { namespace collective {
TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidDeviceOrdinal) { TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidDeviceOrdinal) {
auto construct = []() { NcclDeviceCommunicator comm{-1}; }; auto construct = []() { NcclDeviceCommunicator comm{-1, false}; };
EXPECT_THROW(construct(), dmlc::Error); EXPECT_THROW(construct(), dmlc::Error);
} }