Fix running MGPU gtests (#9200)
This commit is contained in:
parent
5d99b441d5
commit
acd363033e
@ -12,19 +12,22 @@
|
|||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace collective {
|
namespace collective {
|
||||||
|
|
||||||
thread_local int Communicator::device_ordinal_{-1};
|
|
||||||
thread_local std::unique_ptr<DeviceCommunicator> Communicator::device_communicator_{};
|
thread_local std::unique_ptr<DeviceCommunicator> Communicator::device_communicator_{};
|
||||||
|
|
||||||
void Communicator::Finalize() {
|
void Communicator::Finalize() {
|
||||||
communicator_->Shutdown();
|
communicator_->Shutdown();
|
||||||
communicator_.reset(new NoOpCommunicator());
|
communicator_.reset(new NoOpCommunicator());
|
||||||
device_ordinal_ = -1;
|
|
||||||
device_communicator_.reset(nullptr);
|
device_communicator_.reset(nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
DeviceCommunicator* Communicator::GetDevice(int device_ordinal) {
|
DeviceCommunicator* Communicator::GetDevice(int device_ordinal) {
|
||||||
if (!device_communicator_ || device_ordinal_ != device_ordinal) {
|
thread_local auto old_device_ordinal = -1;
|
||||||
device_ordinal_ = device_ordinal;
|
// If the number of GPUs changes, we need to re-initialize NCCL.
|
||||||
|
thread_local auto old_world_size = -1;
|
||||||
|
if (!device_communicator_ || device_ordinal != old_device_ordinal ||
|
||||||
|
communicator_->GetWorldSize() != old_world_size) {
|
||||||
|
old_device_ordinal = device_ordinal;
|
||||||
|
old_world_size = communicator_->GetWorldSize();
|
||||||
#ifdef XGBOOST_USE_NCCL
|
#ifdef XGBOOST_USE_NCCL
|
||||||
if (type_ != CommunicatorType::kFederated) {
|
if (type_ != CommunicatorType::kFederated) {
|
||||||
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, Get()));
|
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, Get()));
|
||||||
|
|||||||
@ -229,7 +229,6 @@ class Communicator {
|
|||||||
static thread_local std::unique_ptr<Communicator> communicator_;
|
static thread_local std::unique_ptr<Communicator> communicator_;
|
||||||
static thread_local CommunicatorType type_;
|
static thread_local CommunicatorType type_;
|
||||||
#if defined(XGBOOST_USE_CUDA)
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
static thread_local int device_ordinal_;
|
|
||||||
static thread_local std::unique_ptr<DeviceCommunicator> device_communicator_;
|
static thread_local std::unique_ptr<DeviceCommunicator> device_communicator_;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user