Add 'sycl' devices to the context (#9691)
Co-authored-by: Dmitry Razdoburdin <>
This commit is contained in:
committed by
GitHub
parent
d4d7097acc
commit
f41a08fda8
@@ -22,6 +22,9 @@ struct CUDAContext;
|
||||
struct DeviceSym {
|
||||
static auto constexpr CPU() { return "cpu"; }
|
||||
static auto constexpr CUDA() { return "cuda"; }
|
||||
static auto constexpr SyclDefault() { return "sycl"; }
|
||||
static auto constexpr SyclCPU() { return "sycl:cpu"; }
|
||||
static auto constexpr SyclGPU() { return "sycl:gpu"; }
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -33,12 +36,19 @@ struct DeviceOrd {
|
||||
static bst_d_ordinal_t constexpr CPUOrdinal() { return -1; }
|
||||
static bst_d_ordinal_t constexpr InvalidOrdinal() { return -2; }
|
||||
|
||||
enum Type : std::int16_t { kCPU = 0, kCUDA = 1 } device{kCPU};
|
||||
// CUDA device ordinal.
|
||||
enum Type : std::int16_t { kCPU = 0, kCUDA = 1,
|
||||
kSyclDefault = 2, kSyclCPU = 3, kSyclGPU = 4} device{kCPU};
|
||||
// CUDA or Sycl device ordinal.
|
||||
bst_d_ordinal_t ordinal{CPUOrdinal()};
|
||||
|
||||
[[nodiscard]] bool IsCUDA() const { return device == kCUDA; }
|
||||
[[nodiscard]] bool IsCPU() const { return device == kCPU; }
|
||||
[[nodiscard]] bool IsSyclDefault() const { return device == kSyclDefault; }
|
||||
[[nodiscard]] bool IsSyclCPU() const { return device == kSyclCPU; }
|
||||
[[nodiscard]] bool IsSyclGPU() const { return device == kSyclGPU; }
|
||||
[[nodiscard]] bool IsSycl() const { return (IsSyclDefault() ||
|
||||
IsSyclCPU() ||
|
||||
IsSyclGPU()); }
|
||||
|
||||
constexpr DeviceOrd() = default;
|
||||
constexpr DeviceOrd(Type type, bst_d_ordinal_t ord) : device{type}, ordinal{ord} {}
|
||||
@@ -60,6 +70,31 @@ struct DeviceOrd {
|
||||
[[nodiscard]] static constexpr auto CUDA(bst_d_ordinal_t ordinal) {
|
||||
return DeviceOrd{kCUDA, ordinal};
|
||||
}
|
||||
/**
|
||||
* @brief Constructor for SYCL.
|
||||
*
|
||||
* @param ordinal SYCL device ordinal.
|
||||
*/
|
||||
[[nodiscard]] constexpr static auto SyclDefault(bst_d_ordinal_t ordinal = -1) {
|
||||
return DeviceOrd{kSyclDefault, ordinal};
|
||||
}
|
||||
/**
|
||||
* @brief Constructor for SYCL CPU.
|
||||
*
|
||||
* @param ordinal SYCL CPU device ordinal.
|
||||
*/
|
||||
[[nodiscard]] constexpr static auto SyclCPU(bst_d_ordinal_t ordinal = -1) {
|
||||
return DeviceOrd{kSyclCPU, ordinal};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Constructor for SYCL GPU.
|
||||
*
|
||||
* @param ordinal SYCL GPU device ordinal.
|
||||
*/
|
||||
[[nodiscard]] constexpr static auto SyclGPU(bst_d_ordinal_t ordinal = -1) {
|
||||
return DeviceOrd{kSyclGPU, ordinal};
|
||||
}
|
||||
|
||||
[[nodiscard]] bool operator==(DeviceOrd const& that) const {
|
||||
return device == that.device && ordinal == that.ordinal;
|
||||
@@ -74,6 +109,12 @@ struct DeviceOrd {
|
||||
return DeviceSym::CPU();
|
||||
case DeviceOrd::kCUDA:
|
||||
return DeviceSym::CUDA() + (':' + std::to_string(ordinal));
|
||||
case DeviceOrd::kSyclDefault:
|
||||
return DeviceSym::SyclDefault() + (':' + std::to_string(ordinal));
|
||||
case DeviceOrd::kSyclCPU:
|
||||
return DeviceSym::SyclCPU() + (':' + std::to_string(ordinal));
|
||||
case DeviceOrd::kSyclGPU:
|
||||
return DeviceSym::SyclGPU() + (':' + std::to_string(ordinal));
|
||||
default: {
|
||||
LOG(FATAL) << "Unknown device.";
|
||||
return "";
|
||||
@@ -142,6 +183,25 @@ struct Context : public XGBoostParameter<Context> {
|
||||
* @brief Is XGBoost running on a CUDA device?
|
||||
*/
|
||||
[[nodiscard]] bool IsCUDA() const { return Device().IsCUDA(); }
|
||||
/**
|
||||
* @brief Is XGBoost running on the default SYCL device?
|
||||
*/
|
||||
[[nodiscard]] bool IsSyclDefault() const { return Device().IsSyclDefault(); }
|
||||
/**
|
||||
* @brief Is XGBoost running on a SYCL CPU?
|
||||
*/
|
||||
[[nodiscard]] bool IsSyclCPU() const { return Device().IsSyclCPU(); }
|
||||
/**
|
||||
* @brief Is XGBoost running on a SYCL GPU?
|
||||
*/
|
||||
[[nodiscard]] bool IsSyclGPU() const { return Device().IsSyclGPU(); }
|
||||
/**
|
||||
* @brief Is XGBoost running on any SYCL device?
|
||||
*/
|
||||
[[nodiscard]] bool IsSycl() const { return IsSyclDefault()
|
||||
|| IsSyclCPU()
|
||||
|| IsSyclGPU(); }
|
||||
|
||||
/**
|
||||
* @brief Get the current device and ordinal.
|
||||
*/
|
||||
@@ -175,6 +235,7 @@ struct Context : public XGBoostParameter<Context> {
|
||||
Context ctx = *this;
|
||||
return ctx.SetDevice(DeviceOrd::CPU());
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Call function based on the current device.
|
||||
*/
|
||||
@@ -196,6 +257,20 @@ struct Context : public XGBoostParameter<Context> {
|
||||
return std::invoke_result_t<CPUFn>();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Call function for sycl devices
|
||||
*/
|
||||
template <typename CPUFn, typename CUDAFn, typename SYCLFn>
|
||||
decltype(auto) DispatchDevice(CPUFn&& cpu_fn, CUDAFn&& cuda_fn, SYCLFn&& sycl_fn) const {
|
||||
static_assert(std::is_same_v<std::invoke_result_t<CPUFn>, std::invoke_result_t<CUDAFn>>);
|
||||
static_assert(std::is_same_v<std::invoke_result_t<CPUFn>, std::invoke_result_t<SYCLFn>>);
|
||||
if (this->Device().IsSycl()) {
|
||||
return sycl_fn();
|
||||
} else {
|
||||
return DispatchDevice(cpu_fn, cuda_fn);
|
||||
}
|
||||
}
|
||||
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(Context) {
|
||||
DMLC_DECLARE_FIELD(seed)
|
||||
|
||||
Reference in New Issue
Block a user