Add 'sycl' devices to the context (#9691)
Co-authored-by: Dmitry Razdoburdin <>
This commit is contained in:
committed by
GitHub
parent
d4d7097acc
commit
f41a08fda8
@@ -104,19 +104,32 @@ DeviceOrd CUDAOrdinal(DeviceOrd device, bool) {
|
||||
// mingw hangs on regex using rtools 430. Basic checks only.
|
||||
CHECK_GE(input.size(), 3) << msg;
|
||||
auto substr = input.substr(0, 3);
|
||||
bool valid = substr == "cpu" || substr == "cud" || substr == "gpu";
|
||||
bool valid = substr == "cpu" || substr == "cud" || substr == "gpu" || substr == "syc";
|
||||
CHECK(valid) << msg;
|
||||
#else
|
||||
std::regex pattern{"gpu(:[0-9]+)?|cuda(:[0-9]+)?|cpu"};
|
||||
std::regex pattern{"gpu(:[0-9]+)?|cuda(:[0-9]+)?|cpu|sycl(:cpu|:gpu)?(:-1|:[0-9]+)?"};
|
||||
if (!std::regex_match(input, pattern)) {
|
||||
fatal();
|
||||
}
|
||||
#endif // defined(__MINGW32__)
|
||||
|
||||
// handle alias
|
||||
std::string s_device = std::regex_replace(input, std::regex{"gpu"}, DeviceSym::CUDA());
|
||||
std::string s_device = input;
|
||||
if (!std::regex_match(s_device, std::regex("sycl(:cpu|:gpu)?(:-1|:[0-9]+)?"))) {
|
||||
s_device = std::regex_replace(s_device, std::regex{"gpu"}, DeviceSym::CUDA());
|
||||
}
|
||||
|
||||
auto split_it = std::find(s_device.cbegin(), s_device.cend(), ':');
|
||||
if (std::regex_match(s_device, std::regex("sycl:(cpu|gpu)?"))) split_it = s_device.cend();
|
||||
|
||||
// For s_device like "sycl:gpu:1"
|
||||
if (split_it != s_device.cend()) {
|
||||
auto second_split_it = std::find(split_it + 1, s_device.cend(), ':');
|
||||
if (second_split_it != s_device.cend()) {
|
||||
split_it = second_split_it;
|
||||
}
|
||||
}
|
||||
|
||||
DeviceOrd device;
|
||||
device.ordinal = DeviceOrd::InvalidOrdinal(); // mark it invalid for check.
|
||||
if (split_it == s_device.cend()) {
|
||||
@@ -125,15 +138,22 @@ DeviceOrd CUDAOrdinal(DeviceOrd device, bool) {
|
||||
device = DeviceOrd::CPU();
|
||||
} else if (s_device == DeviceSym::CUDA()) {
|
||||
device = DeviceOrd::CUDA(0); // use 0 as default;
|
||||
} else if (s_device == DeviceSym::SyclDefault()) {
|
||||
device = DeviceOrd::SyclDefault();
|
||||
} else if (s_device == DeviceSym::SyclCPU()) {
|
||||
device = DeviceOrd::SyclCPU();
|
||||
} else if (s_device == DeviceSym::SyclGPU()) {
|
||||
device = DeviceOrd::SyclGPU();
|
||||
} else {
|
||||
fatal();
|
||||
}
|
||||
} else {
|
||||
// must be CUDA when ordinal is specifed.
|
||||
// must be CUDA or SYCL when ordinal is specifed.
|
||||
// +1 for colon
|
||||
std::size_t offset = std::distance(s_device.cbegin(), split_it) + 1;
|
||||
// substr
|
||||
StringView s_ordinal = {s_device.data() + offset, s_device.size() - offset};
|
||||
StringView s_type = {s_device.data(), offset - 1};
|
||||
if (s_ordinal.empty()) {
|
||||
fatal();
|
||||
}
|
||||
@@ -143,13 +163,23 @@ DeviceOrd CUDAOrdinal(DeviceOrd device, bool) {
|
||||
}
|
||||
CHECK_LE(opt_id.value(), std::numeric_limits<bst_d_ordinal_t>::max())
|
||||
<< "Ordinal value too large.";
|
||||
device = DeviceOrd::CUDA(opt_id.value());
|
||||
if (s_type == DeviceSym::SyclDefault()) {
|
||||
device = DeviceOrd::SyclDefault(opt_id.value());
|
||||
} else if (s_type == DeviceSym::SyclCPU()) {
|
||||
device = DeviceOrd::SyclCPU(opt_id.value());
|
||||
} else if (s_type == DeviceSym::SyclGPU()) {
|
||||
device = DeviceOrd::SyclGPU(opt_id.value());
|
||||
} else {
|
||||
device = DeviceOrd::CUDA(opt_id.value());
|
||||
}
|
||||
}
|
||||
|
||||
if (device.ordinal < DeviceOrd::CPUOrdinal()) {
|
||||
fatal();
|
||||
}
|
||||
device = CUDAOrdinal(device, fail_on_invalid_gpu_id);
|
||||
if (device.IsCUDA()) {
|
||||
device = CUDAOrdinal(device, fail_on_invalid_gpu_id);
|
||||
}
|
||||
|
||||
return device;
|
||||
}
|
||||
@@ -216,7 +246,7 @@ void Context::SetDeviceOrdinal(Args const& kwargs) {
|
||||
|
||||
if (this->IsCPU()) {
|
||||
CHECK_EQ(this->device_.ordinal, DeviceOrd::CPUOrdinal());
|
||||
} else {
|
||||
} else if (this->IsCUDA()) {
|
||||
CHECK_GT(this->device_.ordinal, DeviceOrd::CPUOrdinal());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user