|
|
|
|
@@ -5,11 +5,12 @@
|
|
|
|
|
#pragma GCC diagnostic push
|
|
|
|
|
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
|
|
|
|
|
#pragma GCC diagnostic ignored "-W#pragma-messages"
|
|
|
|
|
#include <rabit/rabit.h>
|
|
|
|
|
#pragma GCC diagnostic pop
|
|
|
|
|
|
|
|
|
|
#include "../sycl/device_manager.h"
|
|
|
|
|
|
|
|
|
|
#include "../../src/collective/communicator-inl.h"
|
|
|
|
|
|
|
|
|
|
namespace xgboost {
|
|
|
|
|
namespace sycl {
|
|
|
|
|
|
|
|
|
|
@@ -21,22 +22,23 @@ namespace sycl {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool not_use_default_selector = (device_spec.ordinal != kDefaultOrdinal) ||
|
|
|
|
|
(rabit::IsDistributed());
|
|
|
|
|
(collective::IsDistributed());
|
|
|
|
|
if (not_use_default_selector) {
|
|
|
|
|
DeviceRegister& device_register = GetDevicesRegister();
|
|
|
|
|
const int device_idx = rabit::IsDistributed() ? rabit::GetRank() : device_spec.ordinal;
|
|
|
|
|
const int device_idx =
|
|
|
|
|
collective::IsDistributed() ? collective::GetRank() : device_spec.ordinal;
|
|
|
|
|
if (device_spec.IsSyclDefault()) {
|
|
|
|
|
auto& devices = device_register.devices;
|
|
|
|
|
CHECK_LT(device_idx, devices.size());
|
|
|
|
|
return devices[device_idx];
|
|
|
|
|
auto& devices = device_register.devices;
|
|
|
|
|
CHECK_LT(device_idx, devices.size());
|
|
|
|
|
return devices[device_idx];
|
|
|
|
|
} else if (device_spec.IsSyclCPU()) {
|
|
|
|
|
auto& cpu_devices = device_register.cpu_devices;
|
|
|
|
|
CHECK_LT(device_idx, cpu_devices.size());
|
|
|
|
|
return cpu_devices[device_idx];
|
|
|
|
|
auto& cpu_devices = device_register.cpu_devices;
|
|
|
|
|
CHECK_LT(device_idx, cpu_devices.size());
|
|
|
|
|
return cpu_devices[device_idx];
|
|
|
|
|
} else {
|
|
|
|
|
auto& gpu_devices = device_register.gpu_devices;
|
|
|
|
|
CHECK_LT(device_idx, gpu_devices.size());
|
|
|
|
|
return gpu_devices[device_idx];
|
|
|
|
|
auto& gpu_devices = device_register.gpu_devices;
|
|
|
|
|
CHECK_LT(device_idx, gpu_devices.size());
|
|
|
|
|
return gpu_devices[device_idx];
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
if (device_spec.IsSyclCPU()) {
|
|
|
|
|
@@ -62,24 +64,25 @@ namespace sycl {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool not_use_default_selector = (device_spec.ordinal != kDefaultOrdinal) ||
|
|
|
|
|
(rabit::IsDistributed());
|
|
|
|
|
(collective::IsDistributed());
|
|
|
|
|
std::lock_guard<std::mutex> guard(queue_registering_mutex);
|
|
|
|
|
if (not_use_default_selector) {
|
|
|
|
|
DeviceRegister& device_register = GetDevicesRegister();
|
|
|
|
|
const int device_idx = rabit::IsDistributed() ? rabit::GetRank() : device_spec.ordinal;
|
|
|
|
|
if (device_spec.IsSyclDefault()) {
|
|
|
|
|
auto& devices = device_register.devices;
|
|
|
|
|
CHECK_LT(device_idx, devices.size());
|
|
|
|
|
queue_register[device_spec.Name()] = ::sycl::queue(devices[device_idx]);
|
|
|
|
|
} else if (device_spec.IsSyclCPU()) {
|
|
|
|
|
auto& cpu_devices = device_register.cpu_devices;
|
|
|
|
|
CHECK_LT(device_idx, cpu_devices.size());
|
|
|
|
|
queue_register[device_spec.Name()] = ::sycl::queue(cpu_devices[device_idx]);;
|
|
|
|
|
} else if (device_spec.IsSyclGPU()) {
|
|
|
|
|
auto& gpu_devices = device_register.gpu_devices;
|
|
|
|
|
CHECK_LT(device_idx, gpu_devices.size());
|
|
|
|
|
queue_register[device_spec.Name()] = ::sycl::queue(gpu_devices[device_idx]);
|
|
|
|
|
}
|
|
|
|
|
DeviceRegister& device_register = GetDevicesRegister();
|
|
|
|
|
const int device_idx =
|
|
|
|
|
collective::IsDistributed() ? collective::GetRank() : device_spec.ordinal;
|
|
|
|
|
if (device_spec.IsSyclDefault()) {
|
|
|
|
|
auto& devices = device_register.devices;
|
|
|
|
|
CHECK_LT(device_idx, devices.size());
|
|
|
|
|
queue_register[device_spec.Name()] = ::sycl::queue(devices[device_idx]);
|
|
|
|
|
} else if (device_spec.IsSyclCPU()) {
|
|
|
|
|
auto& cpu_devices = device_register.cpu_devices;
|
|
|
|
|
CHECK_LT(device_idx, cpu_devices.size());
|
|
|
|
|
queue_register[device_spec.Name()] = ::sycl::queue(cpu_devices[device_idx]);
|
|
|
|
|
} else if (device_spec.IsSyclGPU()) {
|
|
|
|
|
auto& gpu_devices = device_register.gpu_devices;
|
|
|
|
|
CHECK_LT(device_idx, gpu_devices.size());
|
|
|
|
|
queue_register[device_spec.Name()] = ::sycl::queue(gpu_devices[device_idx]);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
if (device_spec.IsSyclCPU()) {
|
|
|
|
|
queue_register[device_spec.Name()] = ::sycl::queue(::sycl::cpu_selector_v);
|
|
|
|
|
|