xgboost/src/collective/communicator.cu
2023-07-07 11:21:35 +08:00

54 lines
1.7 KiB
Plaintext

/*!
* Copyright 2022 XGBoost contributors
*/
#include "communicator.h"
#include "device_communicator.cuh"
#include "device_communicator_adapter.cuh"
#include "noop_communicator.h"
#ifdef XGBOOST_USE_NCCL
#include "nccl_device_communicator.cuh"
#endif
namespace xgboost {
namespace collective {
thread_local std::unique_ptr<DeviceCommunicator> Communicator::device_communicator_{};
void Communicator::Finalize() {
communicator_->Shutdown();
communicator_.reset(new NoOpCommunicator());
device_communicator_.reset(nullptr);
}
/*!
 * \brief Get the device communicator for the given device, (re)creating it when needed.
 *
 * The calling thread's device communicator is rebuilt when none exists yet, when
 * the requested device ordinal differs from the one used last time, or when the
 * world size reported by the host communicator has changed.
 *
 * \param device_ordinal Ordinal of the device the communicator is created for.
 * \return Non-owning pointer to the device communicator; ownership stays with
 *         the thread-local `device_communicator_`.
 */
DeviceCommunicator* Communicator::GetDevice(int device_ordinal) {
  thread_local auto old_device_ordinal = -1;
  // If the number of GPUs changes, we need to re-initialize NCCL.
  thread_local auto old_world_size = -1;

  auto const world_size = communicator_->GetWorldSize();
  if (!device_communicator_ || device_ordinal != old_device_ordinal ||
      world_size != old_world_size) {
#ifdef XGBOOST_USE_NCCL
    switch (type_) {
      case CommunicatorType::kFederated:
        device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
        break;
      case CommunicatorType::kInMemory:
        // The meaning of the boolean flag is defined by NcclDeviceCommunicator.
        device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, true));
        break;
      case CommunicatorType::kRabit:
      default:
        // Rabit and any other communicator type share the same NCCL setup.
        device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false));
        break;
    }
#else
    device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
#endif
    // Record the cache keys only after the replacement succeeded.  If the
    // constructor above throws, `device_communicator_` still holds the previous
    // object; leaving the keys untouched makes the next call retry instead of
    // handing out a communicator built for a different device or world size.
    old_device_ordinal = device_ordinal;
    old_world_size = world_size;
  }
  return device_communicator_.get();
}
} // namespace collective
} // namespace xgboost