Add 'sycl' devices to the context (#9691)

Co-authored-by: Dmitry Razdoburdin <>
This commit is contained in:
Dmitry Razdoburdin
2023-10-26 16:17:56 +02:00
committed by GitHub
parent d4d7097acc
commit f41a08fda8
3 changed files with 207 additions and 9 deletions

View File

@@ -22,6 +22,9 @@ struct CUDAContext;
struct DeviceSym {
static auto constexpr CPU() { return "cpu"; }
static auto constexpr CUDA() { return "cuda"; }
static auto constexpr SyclDefault() { return "sycl"; }
static auto constexpr SyclCPU() { return "sycl:cpu"; }
static auto constexpr SyclGPU() { return "sycl:gpu"; }
};
/**
@@ -33,12 +36,19 @@ struct DeviceOrd {
static bst_d_ordinal_t constexpr CPUOrdinal() { return -1; }
static bst_d_ordinal_t constexpr InvalidOrdinal() { return -2; }
enum Type : std::int16_t { kCPU = 0, kCUDA = 1 } device{kCPU};
// CUDA device ordinal.
enum Type : std::int16_t { kCPU = 0, kCUDA = 1,
kSyclDefault = 2, kSyclCPU = 3, kSyclGPU = 4} device{kCPU};
// CUDA or Sycl device ordinal.
bst_d_ordinal_t ordinal{CPUOrdinal()};
[[nodiscard]] bool IsCUDA() const { return device == kCUDA; }
[[nodiscard]] bool IsCPU() const { return device == kCPU; }
[[nodiscard]] bool IsSyclDefault() const { return device == kSyclDefault; }
[[nodiscard]] bool IsSyclCPU() const { return device == kSyclCPU; }
[[nodiscard]] bool IsSyclGPU() const { return device == kSyclGPU; }
[[nodiscard]] bool IsSycl() const { return (IsSyclDefault() ||
IsSyclCPU() ||
IsSyclGPU()); }
constexpr DeviceOrd() = default;
constexpr DeviceOrd(Type type, bst_d_ordinal_t ord) : device{type}, ordinal{ord} {}
@@ -60,6 +70,31 @@ struct DeviceOrd {
[[nodiscard]] static constexpr auto CUDA(bst_d_ordinal_t ordinal) {
return DeviceOrd{kCUDA, ordinal};
}
/**
* @brief Constructor for SYCL.
*
* @param ordinal SYCL device ordinal.
*/
[[nodiscard]] constexpr static auto SyclDefault(bst_d_ordinal_t ordinal = -1) {
return DeviceOrd{kSyclDefault, ordinal};
}
/**
* @brief Constructor for SYCL CPU.
*
* @param ordinal SYCL CPU device ordinal.
*/
[[nodiscard]] constexpr static auto SyclCPU(bst_d_ordinal_t ordinal = -1) {
return DeviceOrd{kSyclCPU, ordinal};
}
/**
* @brief Constructor for SYCL GPU.
*
* @param ordinal SYCL GPU device ordinal.
*/
[[nodiscard]] constexpr static auto SyclGPU(bst_d_ordinal_t ordinal = -1) {
return DeviceOrd{kSyclGPU, ordinal};
}
[[nodiscard]] bool operator==(DeviceOrd const& that) const {
return device == that.device && ordinal == that.ordinal;
@@ -74,6 +109,12 @@ struct DeviceOrd {
return DeviceSym::CPU();
case DeviceOrd::kCUDA:
return DeviceSym::CUDA() + (':' + std::to_string(ordinal));
case DeviceOrd::kSyclDefault:
return DeviceSym::SyclDefault() + (':' + std::to_string(ordinal));
case DeviceOrd::kSyclCPU:
return DeviceSym::SyclCPU() + (':' + std::to_string(ordinal));
case DeviceOrd::kSyclGPU:
return DeviceSym::SyclGPU() + (':' + std::to_string(ordinal));
default: {
LOG(FATAL) << "Unknown device.";
return "";
@@ -142,6 +183,25 @@ struct Context : public XGBoostParameter<Context> {
* @brief Is XGBoost running on a CUDA device?
*/
[[nodiscard]] bool IsCUDA() const { return Device().IsCUDA(); }
/**
* @brief Is XGBoost running on the default SYCL device?
*/
[[nodiscard]] bool IsSyclDefault() const { return Device().IsSyclDefault(); }
/**
* @brief Is XGBoost running on a SYCL CPU?
*/
[[nodiscard]] bool IsSyclCPU() const { return Device().IsSyclCPU(); }
/**
* @brief Is XGBoost running on a SYCL GPU?
*/
[[nodiscard]] bool IsSyclGPU() const { return Device().IsSyclGPU(); }
/**
* @brief Is XGBoost running on any SYCL device?
*/
[[nodiscard]] bool IsSycl() const { return IsSyclDefault()
|| IsSyclCPU()
|| IsSyclGPU(); }
/**
* @brief Get the current device and ordinal.
*/
@@ -175,6 +235,7 @@ struct Context : public XGBoostParameter<Context> {
Context ctx = *this;
return ctx.SetDevice(DeviceOrd::CPU());
}
/**
* @brief Call function based on the current device.
*/
@@ -196,6 +257,20 @@ struct Context : public XGBoostParameter<Context> {
return std::invoke_result_t<CPUFn>();
}
/**
* @brief Call function for sycl devices
*/
template <typename CPUFn, typename CUDAFn, typename SYCLFn>
decltype(auto) DispatchDevice(CPUFn&& cpu_fn, CUDAFn&& cuda_fn, SYCLFn&& sycl_fn) const {
static_assert(std::is_same_v<std::invoke_result_t<CPUFn>, std::invoke_result_t<CUDAFn>>);
static_assert(std::is_same_v<std::invoke_result_t<CPUFn>, std::invoke_result_t<SYCLFn>>);
if (this->Device().IsSycl()) {
return sycl_fn();
} else {
return DispatchDevice(cpu_fn, cuda_fn);
}
}
// declare parameters
DMLC_DECLARE_PARAMETER(Context) {
DMLC_DECLARE_FIELD(seed)