diff --git a/include/xgboost/context.h b/include/xgboost/context.h index 7748db9f9..6745bcb60 100644 --- a/include/xgboost/context.h +++ b/include/xgboost/context.h @@ -22,6 +22,9 @@ struct CUDAContext; struct DeviceSym { static auto constexpr CPU() { return "cpu"; } static auto constexpr CUDA() { return "cuda"; } + static auto constexpr SyclDefault() { return "sycl"; } + static auto constexpr SyclCPU() { return "sycl:cpu"; } + static auto constexpr SyclGPU() { return "sycl:gpu"; } }; /** @@ -33,12 +36,19 @@ struct DeviceOrd { static bst_d_ordinal_t constexpr CPUOrdinal() { return -1; } static bst_d_ordinal_t constexpr InvalidOrdinal() { return -2; } - enum Type : std::int16_t { kCPU = 0, kCUDA = 1 } device{kCPU}; - // CUDA device ordinal. + enum Type : std::int16_t { kCPU = 0, kCUDA = 1, + kSyclDefault = 2, kSyclCPU = 3, kSyclGPU = 4} device{kCPU}; + // CUDA or Sycl device ordinal. bst_d_ordinal_t ordinal{CPUOrdinal()}; [[nodiscard]] bool IsCUDA() const { return device == kCUDA; } [[nodiscard]] bool IsCPU() const { return device == kCPU; } + [[nodiscard]] bool IsSyclDefault() const { return device == kSyclDefault; } + [[nodiscard]] bool IsSyclCPU() const { return device == kSyclCPU; } + [[nodiscard]] bool IsSyclGPU() const { return device == kSyclGPU; } + [[nodiscard]] bool IsSycl() const { return (IsSyclDefault() || + IsSyclCPU() || + IsSyclGPU()); } constexpr DeviceOrd() = default; constexpr DeviceOrd(Type type, bst_d_ordinal_t ord) : device{type}, ordinal{ord} {} @@ -60,6 +70,31 @@ struct DeviceOrd { [[nodiscard]] static constexpr auto CUDA(bst_d_ordinal_t ordinal) { return DeviceOrd{kCUDA, ordinal}; } + /** + * @brief Constructor for SYCL. + * + * @param ordinal SYCL device ordinal. + */ + [[nodiscard]] constexpr static auto SyclDefault(bst_d_ordinal_t ordinal = -1) { + return DeviceOrd{kSyclDefault, ordinal}; + } + /** + * @brief Constructor for SYCL CPU. + * + * @param ordinal SYCL CPU device ordinal. + */ + [[nodiscard]] constexpr static auto SyclCPU(bst_d_ordinal_t ordinal = -1) { + return DeviceOrd{kSyclCPU, ordinal}; + } + + /** + * @brief Constructor for SYCL GPU. + * + * @param ordinal SYCL GPU device ordinal. + */ + [[nodiscard]] constexpr static auto SyclGPU(bst_d_ordinal_t ordinal = -1) { + return DeviceOrd{kSyclGPU, ordinal}; + } [[nodiscard]] bool operator==(DeviceOrd const& that) const { return device == that.device && ordinal == that.ordinal; @@ -74,6 +109,12 @@ struct DeviceOrd { return DeviceSym::CPU(); case DeviceOrd::kCUDA: return DeviceSym::CUDA() + (':' + std::to_string(ordinal)); + case DeviceOrd::kSyclDefault: + return DeviceSym::SyclDefault() + (':' + std::to_string(ordinal)); + case DeviceOrd::kSyclCPU: + return DeviceSym::SyclCPU() + (':' + std::to_string(ordinal)); + case DeviceOrd::kSyclGPU: + return DeviceSym::SyclGPU() + (':' + std::to_string(ordinal)); default: { LOG(FATAL) << "Unknown device."; return ""; @@ -142,6 +183,25 @@ struct Context : public XGBoostParameter { * @brief Is XGBoost running on a CUDA device? */ [[nodiscard]] bool IsCUDA() const { return Device().IsCUDA(); } + /** + * @brief Is XGBoost running on the default SYCL device? + */ + [[nodiscard]] bool IsSyclDefault() const { return Device().IsSyclDefault(); } + /** + * @brief Is XGBoost running on a SYCL CPU? + */ + [[nodiscard]] bool IsSyclCPU() const { return Device().IsSyclCPU(); } + /** + * @brief Is XGBoost running on a SYCL GPU? + */ + [[nodiscard]] bool IsSyclGPU() const { return Device().IsSyclGPU(); } + /** + * @brief Is XGBoost running on any SYCL device? + */ + [[nodiscard]] bool IsSycl() const { return IsSyclDefault() + || IsSyclCPU() + || IsSyclGPU(); } + /** * @brief Get the current device and ordinal. */ @@ -175,6 +235,7 @@ struct Context : public XGBoostParameter { Context ctx = *this; return ctx.SetDevice(DeviceOrd::CPU()); } + /** * @brief Call function based on the current device. */ @@ -196,6 +257,20 @@ struct Context : public XGBoostParameter { return std::invoke_result_t(); } + /** + * @brief Call function for sycl devices + */ + template + decltype(auto) DispatchDevice(CPUFn&& cpu_fn, CUDAFn&& cuda_fn, SYCLFn&& sycl_fn) const { + static_assert(std::is_same_v, std::invoke_result_t>); + static_assert(std::is_same_v, std::invoke_result_t>); + if (this->Device().IsSycl()) { + return sycl_fn(); + } else { + return DispatchDevice(cpu_fn, cuda_fn); + } + } + // declare parameters DMLC_DECLARE_PARAMETER(Context) { DMLC_DECLARE_FIELD(seed) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 9fe73005a..bad3a2382 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -347,7 +347,7 @@ class _SparkXGBParams( def _validate_gpu_params(self) -> None: """Validate the gpu parameters and gpu configurations""" - if use_cuda(self.getOrDefault(self.device)) or self.getOrDefault(self.use_gpu): + if self._run_on_gpu(): ss = _get_spark_session() sc = ss.sparkContext @@ -414,9 +414,7 @@ class _SparkXGBParams( ) if self.getOrDefault(self.features_cols): - if not use_cuda(self.getOrDefault(self.device)) and not self.getOrDefault( - self.use_gpu - ): + if not self._run_on_gpu(): raise ValueError( "features_col param with list value requires `device=cuda`." ) @@ -473,6 +471,15 @@ class _SparkXGBParams( self._validate_gpu_params() + def _run_on_gpu(self) -> bool: + """If train or transform on the gpu according to the parameters""" + + return ( + use_cuda(self.getOrDefault(self.device)) + or self.getOrDefault(self.use_gpu) + or self.getOrDefault(self.getParam("tree_method")) == "gpu_hist" + ) + def _validate_and_convert_feature_col_as_float_col_list( dataset: DataFrame, features_col_names: List[str] @@ -905,7 +912,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable): """Check if stage-level scheduling is not needed, return true to skip stage-level scheduling""" - if use_cuda(self.getOrDefault(self.device)) or self.getOrDefault(self.use_gpu): + if self._run_on_gpu(): ss = _get_spark_session() sc = ss.sparkContext @@ -1022,9 +1029,8 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable): dmatrix_kwargs, ) = self._get_xgb_parameters(dataset) - run_on_gpu = use_cuda(self.getOrDefault(self.device)) or self.getOrDefault( - self.use_gpu - ) + run_on_gpu = self._run_on_gpu() + is_local = _is_local(_get_spark_session().sparkContext) num_workers = self.getOrDefault(self.num_workers) @@ -1318,12 +1324,15 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable): dataset = dataset.drop(pred_struct_col) return dataset - def _gpu_transform(self) -> bool: - """If gpu is used to do the prediction, true to gpu prediction""" + def _run_on_gpu(self) -> bool: + """If gpu is used to do the prediction according to the parameters + and spark configurations""" + + use_gpu_by_params = super()._run_on_gpu() if _is_local(_get_spark_session().sparkContext): - # if it's local model, we just use the internal "device" - return use_cuda(self.getOrDefault(self.device)) + # if it's local model, no need to check the spark configurations + return use_gpu_by_params gpu_per_task = ( _get_spark_session() @@ -1333,15 +1342,15 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable): # User don't set gpu configurations, just use cpu if gpu_per_task is None: - if use_cuda(self.getOrDefault(self.device)): + if use_gpu_by_params: get_logger("XGBoost-PySpark").warning( "Do the prediction on the CPUs since " "no gpu configurations are set" ) return False - # User already sets the gpu configurations, we just use the internal "device". - return use_cuda(self.getOrDefault(self.device)) + # User already sets the gpu configurations. + return use_gpu_by_params def _transform(self, dataset: DataFrame) -> DataFrame: # pylint: disable=too-many-statements, too-many-locals @@ -1367,7 +1376,7 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable): _, schema = self._out_schema() is_local = _is_local(_get_spark_session().sparkContext) - run_on_gpu = self._gpu_transform() + run_on_gpu = self._run_on_gpu() @pandas_udf(schema) # type: ignore def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]: @@ -1381,9 +1390,10 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable): dev_ordinal = -1 - if is_cudf_available(): - if is_local: - if run_on_gpu and is_cupy_available(): + msg = "Do the inference on the CPUs" + if run_on_gpu: + if is_cudf_available() and is_cupy_available(): + if is_local: import cupy as cp # pylint: disable=import-error total_gpus = cp.cuda.runtime.getDeviceCount() @@ -1392,24 +1402,19 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable): # For transform local mode, default the dev_ordinal to # (partition id) % gpus. dev_ordinal = partition_id % total_gpus - elif run_on_gpu: - dev_ordinal = _get_gpu_id(context) + else: + dev_ordinal = _get_gpu_id(context) - if dev_ordinal >= 0: - device = "cuda:" + str(dev_ordinal) - get_logger("XGBoost-PySpark").info( - "Do the inference with device: %s", device - ) - model.set_params(device=device) + if dev_ordinal >= 0: + device = "cuda:" + str(dev_ordinal) + msg = "Do the inference with device: " + device + model.set_params(device=device) + else: + msg = "Couldn't get the correct gpu id, fallback the inference on the CPUs" else: - get_logger("XGBoost-PySpark").info("Do the inference on the CPUs") - else: - msg = ( - "CUDF is unavailable, fallback the inference on the CPUs" - if run_on_gpu - else "Do the inference on the CPUs" - ) - get_logger("XGBoost-PySpark").info(msg) + msg = "CUDF or Cupy is unavailable, fallback the inference on the CPUs" + + get_logger("XGBoost-PySpark").info(msg) def to_gpu_if_possible(data: ArrayLike) -> ArrayLike: """Move the data to gpu if possible""" diff --git a/src/context.cc b/src/context.cc index 7b74a69e0..108ad9ce7 100644 --- a/src/context.cc +++ b/src/context.cc @@ -104,19 +104,44 @@ DeviceOrd CUDAOrdinal(DeviceOrd device, bool) { // mingw hangs on regex using rtools 430. Basic checks only. CHECK_GE(input.size(), 3) << msg; auto substr = input.substr(0, 3); - bool valid = substr == "cpu" || substr == "cud" || substr == "gpu"; + bool valid = substr == "cpu" || substr == "cud" || substr == "gpu" || substr == "syc"; CHECK(valid) << msg; #else - std::regex pattern{"gpu(:[0-9]+)?|cuda(:[0-9]+)?|cpu"}; + std::regex pattern{"gpu(:[0-9]+)?|cuda(:[0-9]+)?|cpu|sycl(:cpu|:gpu)?(:-1|:[0-9]+)?"}; if (!std::regex_match(input, pattern)) { fatal(); } #endif // defined(__MINGW32__) // handle alias - std::string s_device = std::regex_replace(input, std::regex{"gpu"}, DeviceSym::CUDA()); +#if defined(__MINGW32__) + // mingw hangs on regex using rtools 430. Basic checks only. + bool is_sycl = (substr == "syc"); +#else + bool is_sycl = std::regex_match(input, std::regex("sycl(:cpu|:gpu)?(:-1|:[0-9]+)?")); +#endif // defined(__MINGW32__) + + std::string s_device = input; + if (!is_sycl) { + s_device = std::regex_replace(s_device, std::regex{"gpu"}, DeviceSym::CUDA()); + } auto split_it = std::find(s_device.cbegin(), s_device.cend(), ':'); + + // For these cases we need to move iterator to the end, not to look for a ordinal. + if ((s_device == "sycl:cpu") || + (s_device == "sycl:gpu")) { + split_it = s_device.cend(); + } + + // For s_device like "sycl:gpu:1" + if (split_it != s_device.cend()) { + auto second_split_it = std::find(split_it + 1, s_device.cend(), ':'); + if (second_split_it != s_device.cend()) { + split_it = second_split_it; + } + } + DeviceOrd device; device.ordinal = DeviceOrd::InvalidOrdinal(); // mark it invalid for check. if (split_it == s_device.cend()) { @@ -125,15 +150,22 @@ DeviceOrd CUDAOrdinal(DeviceOrd device, bool) { device = DeviceOrd::CPU(); } else if (s_device == DeviceSym::CUDA()) { device = DeviceOrd::CUDA(0); // use 0 as default; + } else if (s_device == DeviceSym::SyclDefault()) { + device = DeviceOrd::SyclDefault(); + } else if (s_device == DeviceSym::SyclCPU()) { + device = DeviceOrd::SyclCPU(); + } else if (s_device == DeviceSym::SyclGPU()) { + device = DeviceOrd::SyclGPU(); } else { fatal(); } } else { - // must be CUDA when ordinal is specifed. + // must be CUDA or SYCL when ordinal is specifed. // +1 for colon std::size_t offset = std::distance(s_device.cbegin(), split_it) + 1; // substr StringView s_ordinal = {s_device.data() + offset, s_device.size() - offset}; + StringView s_type = {s_device.data(), offset - 1}; if (s_ordinal.empty()) { fatal(); } @@ -143,13 +175,23 @@ DeviceOrd CUDAOrdinal(DeviceOrd device, bool) { } CHECK_LE(opt_id.value(), std::numeric_limits::max()) << "Ordinal value too large."; - device = DeviceOrd::CUDA(opt_id.value()); + if (s_type == DeviceSym::SyclDefault()) { + device = DeviceOrd::SyclDefault(opt_id.value()); + } else if (s_type == DeviceSym::SyclCPU()) { + device = DeviceOrd::SyclCPU(opt_id.value()); + } else if (s_type == DeviceSym::SyclGPU()) { + device = DeviceOrd::SyclGPU(opt_id.value()); + } else { + device = DeviceOrd::CUDA(opt_id.value()); + } } if (device.ordinal < DeviceOrd::CPUOrdinal()) { fatal(); } - device = CUDAOrdinal(device, fail_on_invalid_gpu_id); + if (device.IsCUDA()) { + device = CUDAOrdinal(device, fail_on_invalid_gpu_id); + } return device; } @@ -216,7 +258,7 @@ void Context::SetDeviceOrdinal(Args const& kwargs) { if (this->IsCPU()) { CHECK_EQ(this->device_.ordinal, DeviceOrd::CPUOrdinal()); - } else { + } else if (this->IsCUDA()) { CHECK_GT(this->device_.ordinal, DeviceOrd::CPUOrdinal()); } } diff --git a/tests/cpp/test_context.cc b/tests/cpp/test_context.cc index 2fdf04aa1..4eb765c93 100644 --- a/tests/cpp/test_context.cc +++ b/tests/cpp/test_context.cc @@ -45,4 +45,97 @@ TEST(Context, ErrorInit) { ASSERT_NE(msg.find("foo"), std::string::npos); } } + +TEST(Context, SYCL) { + Context ctx; + // Default SYCL device + { + ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); + ASSERT_EQ(ctx.Device(), DeviceOrd::SyclDefault()); + ASSERT_EQ(ctx.Ordinal(), -1); + + std::int32_t flag{0}; + ctx.DispatchDevice([&] { flag = -1; }, [&] { flag = 1; }, [&] { flag = 2; }); + ASSERT_EQ(flag, 2); + + std::stringstream ss; + ss << ctx.Device(); + ASSERT_EQ(ss.str(), "sycl:-1"); + } + + // SYCL device with idx + { + ctx.UpdateAllowUnknown(Args{{"device", "sycl:42"}}); + ASSERT_EQ(ctx.Device(), DeviceOrd::SyclDefault(42)); + ASSERT_EQ(ctx.Ordinal(), 42); + + std::int32_t flag{0}; + ctx.DispatchDevice([&] { flag = -1; }, [&] { flag = 1; }, [&] { flag = 2; }); + ASSERT_EQ(flag, 2); + + std::stringstream ss; + ss << ctx.Device(); + ASSERT_EQ(ss.str(), "sycl:42"); + } + + // SYCL cpu + { + ctx.UpdateAllowUnknown(Args{{"device", "sycl:cpu"}}); + ASSERT_EQ(ctx.Device(), DeviceOrd::SyclCPU()); + ASSERT_EQ(ctx.Ordinal(), -1); + + std::int32_t flag{0}; + ctx.DispatchDevice([&] { flag = -1; }, [&] { flag = 1; }, [&] { flag = 2; }); + ASSERT_EQ(flag, 2); + + std::stringstream ss; + ss << ctx.Device(); + ASSERT_EQ(ss.str(), "sycl:cpu:-1"); + } + + // SYCL cpu with idx + { + ctx.UpdateAllowUnknown(Args{{"device", "sycl:cpu:42"}}); + ASSERT_EQ(ctx.Device(), DeviceOrd::SyclCPU(42)); + ASSERT_EQ(ctx.Ordinal(), 42); + + std::int32_t flag{0}; + ctx.DispatchDevice([&] { flag = -1; }, [&] { flag = 1; }, [&] { flag = 2; }); + ASSERT_EQ(flag, 2); + + std::stringstream ss; + ss << ctx.Device(); + ASSERT_EQ(ss.str(), "sycl:cpu:42"); + } + + // SYCL gpu + { + ctx.UpdateAllowUnknown(Args{{"device", "sycl:gpu"}}); + ASSERT_EQ(ctx.Device(), DeviceOrd::SyclGPU()); + ASSERT_EQ(ctx.Ordinal(), -1); + + std::int32_t flag{0}; + ctx.DispatchDevice([&] { flag = -1; }, [&] { flag = 1; }, [&] { flag = 2; }); + ASSERT_EQ(flag, 2); + + std::stringstream ss; + ss << ctx.Device(); + ASSERT_EQ(ss.str(), "sycl:gpu:-1"); + } + + // SYCL gpu with idx + { + ctx.UpdateAllowUnknown(Args{{"device", "sycl:gpu:42"}}); + ASSERT_EQ(ctx.Device(), DeviceOrd::SyclGPU(42)); + ASSERT_EQ(ctx.Ordinal(), 42); + + std::int32_t flag{0}; + ctx.DispatchDevice([&] { flag = -1; }, [&] { flag = 1; }, [&] { flag = 2; }); + ASSERT_EQ(flag, 2); + + std::stringstream ss; + ss << ctx.Device(); + ASSERT_EQ(ss.str(), "sycl:gpu:42"); + } +} } // namespace xgboost diff --git a/tests/test_distributed/test_gpu_with_spark/test_gpu_spark.py b/tests/test_distributed/test_gpu_with_spark/test_gpu_spark.py index 513554e43..3bf94c954 100644 --- a/tests/test_distributed/test_gpu_with_spark/test_gpu_spark.py +++ b/tests/test_distributed/test_gpu_with_spark/test_gpu_spark.py @@ -251,10 +251,10 @@ def test_gpu_transform(spark_diabetes_dataset) -> None: model: SparkXGBRegressorModel = regressor.fit(train_df) # The model trained with GPUs, and transform with GPU configurations. - assert model._gpu_transform() + assert model._run_on_gpu() model.set_device("cpu") - assert not model._gpu_transform() + assert not model._run_on_gpu() # without error cpu_rows = model.transform(test_df).select("prediction").collect() @@ -263,11 +263,11 @@ def test_gpu_transform(spark_diabetes_dataset) -> None: # The model trained with CPUs. Even with GPU configurations, # still prefer transforming with CPUs - assert not model._gpu_transform() + assert not model._run_on_gpu() # Set gpu transform explicitly. model.set_device("cuda") - assert model._gpu_transform() + assert model._run_on_gpu() # without error gpu_rows = model.transform(test_df).select("prediction").collect() diff --git a/tests/test_distributed/test_with_spark/test_spark_local.py b/tests/test_distributed/test_with_spark/test_spark_local.py index 861e67a75..2c5ee3690 100644 --- a/tests/test_distributed/test_with_spark/test_spark_local.py +++ b/tests/test_distributed/test_with_spark/test_spark_local.py @@ -888,6 +888,22 @@ class TestPySparkLocal: clf = SparkXGBClassifier(device="cuda") clf._validate_params() + def test_gpu_params(self) -> None: + clf = SparkXGBClassifier() + assert not clf._run_on_gpu() + + clf = SparkXGBClassifier(device="cuda", tree_method="hist") + assert clf._run_on_gpu() + + clf = SparkXGBClassifier(device="cuda") + assert clf._run_on_gpu() + + clf = SparkXGBClassifier(tree_method="gpu_hist") + assert clf._run_on_gpu() + + clf = SparkXGBClassifier(use_gpu=True) + assert clf._run_on_gpu() + def test_gpu_transform(self, clf_data: ClfData) -> None: """local mode""" classifier = SparkXGBClassifier(device="cpu") @@ -898,23 +914,23 @@ class TestPySparkLocal: model.write().overwrite().save(path) # The model trained with CPU, transform defaults to cpu - assert not model._gpu_transform() + assert not model._run_on_gpu() # without error model.transform(clf_data.cls_df_test).collect() model.set_device("cuda") - assert model._gpu_transform() + assert model._run_on_gpu() model_loaded = SparkXGBClassifierModel.load(path) # The model trained with CPU, transform defaults to cpu - assert not model_loaded._gpu_transform() + assert not model_loaded._run_on_gpu() # without error model_loaded.transform(clf_data.cls_df_test).collect() model_loaded.set_device("cuda") - assert model_loaded._gpu_transform() + assert model_loaded._run_on_gpu() class XgboostLocalTest(SparkTestCase):