Fix running MGPU gtests (#9200)
This commit is contained in:
parent
5d99b441d5
commit
acd363033e
@ -12,19 +12,22 @@
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
thread_local int Communicator::device_ordinal_{-1};
|
||||
thread_local std::unique_ptr<DeviceCommunicator> Communicator::device_communicator_{};
|
||||
|
||||
void Communicator::Finalize() {
|
||||
communicator_->Shutdown();
|
||||
communicator_.reset(new NoOpCommunicator());
|
||||
device_ordinal_ = -1;
|
||||
device_communicator_.reset(nullptr);
|
||||
}
|
||||
|
||||
DeviceCommunicator* Communicator::GetDevice(int device_ordinal) {
|
||||
if (!device_communicator_ || device_ordinal_ != device_ordinal) {
|
||||
device_ordinal_ = device_ordinal;
|
||||
thread_local auto old_device_ordinal = -1;
|
||||
// If the number of GPUs changes, we need to re-initialize NCCL.
|
||||
thread_local auto old_world_size = -1;
|
||||
if (!device_communicator_ || device_ordinal != old_device_ordinal ||
|
||||
communicator_->GetWorldSize() != old_world_size) {
|
||||
old_device_ordinal = device_ordinal;
|
||||
old_world_size = communicator_->GetWorldSize();
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
if (type_ != CommunicatorType::kFederated) {
|
||||
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, Get()));
|
||||
|
||||
@ -229,7 +229,6 @@ class Communicator {
|
||||
static thread_local std::unique_ptr<Communicator> communicator_;
|
||||
static thread_local CommunicatorType type_;
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
static thread_local int device_ordinal_;
|
||||
static thread_local std::unique_ptr<DeviceCommunicator> device_communicator_;
|
||||
#endif
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user