From acd363033e1127898a1e1fec7db0efc37fc2a98e Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 25 May 2023 14:26:38 -0700 Subject: [PATCH] Fix running MGPU gtests (#9200) --- src/collective/communicator.cu | 11 +++++++---- src/collective/communicator.h | 1 - 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/collective/communicator.cu b/src/collective/communicator.cu index 0880741f9..8bd10382d 100644 --- a/src/collective/communicator.cu +++ b/src/collective/communicator.cu @@ -12,19 +12,22 @@ namespace xgboost { namespace collective { -thread_local int Communicator::device_ordinal_{-1}; thread_local std::unique_ptr<DeviceCommunicator> Communicator::device_communicator_{}; void Communicator::Finalize() { communicator_->Shutdown(); communicator_.reset(new NoOpCommunicator()); - device_ordinal_ = -1; device_communicator_.reset(nullptr); } DeviceCommunicator* Communicator::GetDevice(int device_ordinal) { - if (!device_communicator_ || device_ordinal_ != device_ordinal) { - device_ordinal_ = device_ordinal; + thread_local auto old_device_ordinal = -1; + // If the number of GPUs changes, we need to re-initialize NCCL. + thread_local auto old_world_size = -1; + if (!device_communicator_ || device_ordinal != old_device_ordinal || + communicator_->GetWorldSize() != old_world_size) { + old_device_ordinal = device_ordinal; + old_world_size = communicator_->GetWorldSize(); #ifdef XGBOOST_USE_NCCL if (type_ != CommunicatorType::kFederated) { device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, Get())); diff --git a/src/collective/communicator.h b/src/collective/communicator.h index 885a8d438..6cda5e47c 100644 --- a/src/collective/communicator.h +++ b/src/collective/communicator.h @@ -229,7 +229,6 @@ class Communicator { static thread_local std::unique_ptr<Communicator> communicator_; static thread_local CommunicatorType type_; #if defined(XGBOOST_USE_CUDA) - static thread_local int device_ordinal_; static thread_local std::unique_ptr<DeviceCommunicator> device_communicator_; #endif