54 lines
1.7 KiB
Plaintext
54 lines
1.7 KiB
Plaintext
/*!
|
|
* Copyright 2022 XGBoost contributors
|
|
*/
|
|
#include "communicator.h"
|
|
#include "device_communicator.cuh"
|
|
#include "device_communicator_adapter.cuh"
|
|
#include "noop_communicator.h"
|
|
#ifdef XGBOOST_USE_NCCL
|
|
#include "nccl_device_communicator.cuh"
|
|
#endif
|
|
|
|
namespace xgboost {
|
|
namespace collective {
|
|
|
|
thread_local std::unique_ptr<DeviceCommunicator> Communicator::device_communicator_{};
|
|
|
|
void Communicator::Finalize() {
|
|
communicator_->Shutdown();
|
|
communicator_.reset(new NoOpCommunicator());
|
|
device_communicator_.reset(nullptr);
|
|
}
|
|
|
|
DeviceCommunicator* Communicator::GetDevice(int device_ordinal) {
|
|
thread_local auto old_device_ordinal = -1;
|
|
// If the number of GPUs changes, we need to re-initialize NCCL.
|
|
thread_local auto old_world_size = -1;
|
|
if (!device_communicator_ || device_ordinal != old_device_ordinal ||
|
|
communicator_->GetWorldSize() != old_world_size) {
|
|
old_device_ordinal = device_ordinal;
|
|
old_world_size = communicator_->GetWorldSize();
|
|
#ifdef XGBOOST_USE_NCCL
|
|
switch (type_) {
|
|
case CommunicatorType::kRabit:
|
|
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false));
|
|
break;
|
|
case CommunicatorType::kFederated:
|
|
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
|
|
break;
|
|
case CommunicatorType::kInMemory:
|
|
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, true));
|
|
break;
|
|
default:
|
|
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false));
|
|
}
|
|
#else
|
|
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
|
|
#endif
|
|
}
|
|
return device_communicator_.get();
|
|
}
|
|
|
|
} // namespace collective
|
|
} // namespace xgboost
|