xgboost/plugin/sycl/device_manager.cc
Jiaming Yuan a5a58102e5
Revamp the rabit implementation. (#10112)
This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features:
- Federated learning for both CPU and GPU.
- NCCL.
- More data types.
- A unified interface for all the underlying implementations.
- Improved timeout handling for both tracker and workers.
- Exhaustive tests with metrics (fixed a couple of bugs along the way).
- A reusable tracker for Python and JVM packages.
2024-05-20 11:56:23 +08:00

128 lines
5.0 KiB
C++

/*!
* Copyright 2017-2023 by Contributors
* \file device_manager.cc
*/
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
#pragma GCC diagnostic ignored "-W#pragma-messages"
#pragma GCC diagnostic pop
#include "../sycl/device_manager.h"
#include "../../src/collective/communicator-inl.h"
namespace xgboost {
namespace sycl {
/*!
 * \brief Resolve a device specification to a concrete SYCL device.
 *
 * A warning is logged when \p device_spec is not a SYCL device.
 *
 * When the ordinal equals kDefaultOrdinal and the run is not distributed, the
 * device is constructed from the matching SYCL selector (cpu, gpu or default).
 * Otherwise the device is taken from the device register by index: the
 * collective rank for a distributed run, the ordinal of \p device_spec
 * otherwise.
 *
 * \param device_spec Requested device.
 * \return The selected SYCL device.
 */
::sycl::device DeviceManager::GetDevice(const DeviceOrd& device_spec) const {
  if (!device_spec.IsSycl()) {
    LOG(WARNING) << "Sycl kernel is executed with non-sycl context: "
                 << device_spec.Name() << ". "
                 << "Default sycl device_selector will be used.";
  }

  // Selector path: no explicit ordinal and a single-process run.
  const bool use_selector =
      (device_spec.ordinal == kDefaultOrdinal) && !collective::IsDistributed();
  if (use_selector) {
    if (device_spec.IsSyclCPU()) {
      return ::sycl::device(::sycl::cpu_selector_v);
    }
    if (device_spec.IsSyclGPU()) {
      return ::sycl::device(::sycl::gpu_selector_v);
    }
    return ::sycl::device(::sycl::default_selector_v);
  }

  // Indexed path: look the device up in the register.
  DeviceRegister& device_register = GetDevicesRegister();
  int device_idx = device_spec.ordinal;
  if (collective::IsDistributed()) {
    // In a distributed run every worker uses the device matching its rank.
    device_idx = collective::GetRank();
  }

  if (device_spec.IsSyclDefault()) {
    auto& devices = device_register.devices;
    CHECK_LT(device_idx, devices.size());
    return devices[device_idx];
  }
  if (device_spec.IsSyclCPU()) {
    auto& cpu_devices = device_register.cpu_devices;
    CHECK_LT(device_idx, cpu_devices.size());
    return cpu_devices[device_idx];
  }
  // Every remaining specification is served from the GPU list.
  auto& gpu_devices = device_register.gpu_devices;
  CHECK_LT(device_idx, gpu_devices.size());
  return gpu_devices[device_idx];
}
/*!
 * \brief Return the SYCL queue for a device specification, creating and
 *        caching it on first use.
 *
 * Queues are cached in the queue register under the name of \p device_spec.
 * A warning is logged when \p device_spec is not a SYCL device.
 *
 * Device selection for a new queue:
 *  - ordinal equals kDefaultOrdinal and the run is not distributed: the queue
 *    is built from the matching SYCL selector (cpu, gpu or default);
 *  - otherwise the device is taken from the device register by index, the
 *    index being the collective rank for a distributed run and the ordinal of
 *    \p device_spec otherwise.  Specifications that are neither the SYCL
 *    default nor a SYCL CPU device use the GPU list, as GetDevice() does.
 *
 * The register is only accessed while queue_registering_mutex is held.
 *
 * \param device_spec Requested device.
 * \return The queue registered for \p device_spec.
 */
::sycl::queue DeviceManager::GetQueue(const DeviceOrd& device_spec) const {
  if (!device_spec.IsSycl()) {
    LOG(WARNING) << "Sycl kernel is executed with non-sycl context: "
                 << device_spec.Name() << ". "
                 << "Default sycl device_selector will be used.";
  }

  // Build the cache key once instead of on every lookup/insert.
  const auto queue_name = device_spec.Name();

  // The lock has to cover the lookup as well: reading the register while
  // another thread inserts into it is a data race.
  std::lock_guard<std::mutex> guard(queue_registering_mutex);
  QueueRegister_t& queue_register = GetQueueRegister();
  const auto cached = queue_register.find(queue_name);
  if (cached != queue_register.end()) {
    return cached->second;
  }

  const bool use_registered_device =
      (device_spec.ordinal != kDefaultOrdinal) || collective::IsDistributed();
  if (use_registered_device) {
    DeviceRegister& device_register = GetDevicesRegister();
    const int device_idx =
        collective::IsDistributed() ? collective::GetRank() : device_spec.ordinal;
    // A negative index would also fail the CHECK_LT below after conversion to
    // an unsigned type; checking it explicitly yields a readable message.
    CHECK_GE(device_idx, 0);
    if (device_spec.IsSyclDefault()) {
      auto& devices = device_register.devices;
      CHECK_LT(device_idx, devices.size());
      queue_register[queue_name] = ::sycl::queue(devices[device_idx]);
    } else if (device_spec.IsSyclCPU()) {
      auto& cpu_devices = device_register.cpu_devices;
      CHECK_LT(device_idx, cpu_devices.size());
      queue_register[queue_name] = ::sycl::queue(cpu_devices[device_idx]);
    } else {
      // Includes non-SYCL specifications.  Without this fallback nothing is
      // registered for them and the lookup at the end of the function throws
      // std::out_of_range.  GetDevice() resolves the same case to the GPU list.
      auto& gpu_devices = device_register.gpu_devices;
      CHECK_LT(device_idx, gpu_devices.size());
      queue_register[queue_name] = ::sycl::queue(gpu_devices[device_idx]);
    }
  } else {
    if (device_spec.IsSyclCPU()) {
      queue_register[queue_name] = ::sycl::queue(::sycl::cpu_selector_v);
    } else if (device_spec.IsSyclGPU()) {
      queue_register[queue_name] = ::sycl::queue(::sycl::gpu_selector_v);
    } else {
      queue_register[queue_name] = ::sycl::queue(::sycl::default_selector_v);
    }
  }
  return queue_register.at(queue_name);
}
/*!
 * \brief Return the process-wide register of SYCL devices, filling it on
 *        first use.
 *
 * The register is populated from ::sycl::device::get_devices().  Every device
 * is appended to the `devices` list; CPU and GPU devices are additionally
 * appended to `cpu_devices` and `gpu_devices` respectively.  The index and
 * name of each discovered device are logged at INFO level.
 *
 * \return Reference to the function-local static register.
 */
DeviceManager::DeviceRegister& DeviceManager::GetDevicesRegister() const {
  static DeviceRegister device_register;

  // Acquire the lock before inspecting the register.  Testing for emptiness
  // outside the lock lets two threads both see an empty register and both
  // populate it (duplicating every entry), and races with push_back.
  std::lock_guard<std::mutex> guard(device_registering_mutex);
  if (device_register.devices.empty()) {
    const std::vector<::sycl::device> devices = ::sycl::device::get_devices();
    for (size_t i = 0; i < devices.size(); i++) {
      const auto& device = devices[i];
      LOG(INFO) << "device_index = " << i << ", name = "
                << device.get_info<::sycl::info::device::name>();

      device_register.devices.push_back(device);
      if (device.is_cpu()) {
        device_register.cpu_devices.push_back(device);
      } else if (device.is_gpu()) {
        device_register.gpu_devices.push_back(device);
      }
    }
  }
  return device_register;
}
/*!
 * \brief Access the queue register.
 *
 * The register is a function-local static, so every call returns a reference
 * to the same object.
 *
 * \return Reference to the queue register.
 */
DeviceManager::QueueRegister_t& DeviceManager::GetQueueRegister() const {
  static QueueRegister_t registry;
  return registry;
}
} // namespace sycl
} // namespace xgboost