From bba6aa74fbb0eb67b1a2f8f5b0fe3868d33b9df0 Mon Sep 17 00:00:00 2001
From: Dmitry Razdoburdin
Date: Mon, 9 Sep 2024 08:14:07 +0200
Subject: [PATCH] [SYCL] Fix for sycl support with sklearn estimators (#10806)

---------

Co-authored-by: Dmitry Razdoburdin <>
---
 python-package/xgboost/dask/__init__.py     |  6 +++-
 python-package/xgboost/sklearn.py           |  7 ++--
 python-package/xgboost/spark/core.py        |  5 ++-
 tests/python-sycl/test_sycl_with_sklearn.py | 37 +++++++++++++++++++++
 4 files changed, 50 insertions(+), 5 deletions(-)
 create mode 100644 tests/python-sycl/test_sycl_with_sklearn.py

diff --git a/python-package/xgboost/dask/__init__.py b/python-package/xgboost/dask/__init__.py
index 7a565f5f2..a2edd26b9 100644
--- a/python-package/xgboost/dask/__init__.py
+++ b/python-package/xgboost/dask/__init__.py
@@ -1568,6 +1568,7 @@ def inplace_predict(  # pylint: disable=unused-argument
 
 async def _async_wrap_evaluation_matrices(
     client: Optional["distributed.Client"],
+    device: Optional[str],
     tree_method: Optional[str],
     max_bin: Optional[int],
     **kwargs: Any,
@@ -1575,7 +1576,7 @@ async def _async_wrap_evaluation_matrices(
     """A switch function for async environment."""
 
     def _dispatch(ref: Optional[DaskDMatrix], **kwargs: Any) -> DaskDMatrix:
-        if _can_use_qdm(tree_method):
+        if _can_use_qdm(tree_method, device):
             return DaskQuantileDMatrix(
                 client=client, ref=ref, max_bin=max_bin, **kwargs
             )
@@ -1776,6 +1777,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
         params = self.get_xgb_params()
         dtrain, evals = await _async_wrap_evaluation_matrices(
             client=self.client,
+            device=self.device,
             tree_method=self.tree_method,
             max_bin=self.max_bin,
             X=X,
@@ -1865,6 +1867,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
         params = self.get_xgb_params()
         dtrain, evals = await _async_wrap_evaluation_matrices(
             self.client,
+            device=self.device,
             tree_method=self.tree_method,
             max_bin=self.max_bin,
             X=X,
@@ -2067,6 +2070,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
         params = self.get_xgb_params()
         dtrain, evals = await _async_wrap_evaluation_matrices(
             self.client,
+            device=self.device,
             tree_method=self.tree_method,
             max_bin=self.max_bin,
             X=X,
diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py
index e295246e1..45a1d4b67 100644
--- a/python-package/xgboost/sklearn.py
+++ b/python-package/xgboost/sklearn.py
@@ -67,8 +67,9 @@ def _check_rf_callback(
     )
 
 
-def _can_use_qdm(tree_method: Optional[str]) -> bool:
-    return tree_method in ("hist", "gpu_hist", None, "auto")
+def _can_use_qdm(tree_method: Optional[str], device: Optional[str]) -> bool:
+    not_sycl = (device is None) or (not device.startswith("sycl"))
+    return tree_method in ("hist", "gpu_hist", None, "auto") and not_sycl
 
 
 class _SklObjWProto(Protocol):  # pylint: disable=too-few-public-methods
@@ -1031,7 +1032,7 @@ class XGBModel(XGBModelBase):
 
     def _create_dmatrix(self, ref: Optional[DMatrix], **kwargs: Any) -> DMatrix:
         # Use `QuantileDMatrix` to save memory.
-        if _can_use_qdm(self.tree_method) and self.booster != "gblinear":
+        if _can_use_qdm(self.tree_method, self.device) and self.booster != "gblinear":
             try:
                 return QuantileDMatrix(
                     **kwargs, ref=ref, nthread=self.n_jobs, max_bin=self.max_bin
diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py
index 6700aeed8..7eef43842 100644
--- a/python-package/xgboost/spark/core.py
+++ b/python-package/xgboost/spark/core.py
@@ -1028,7 +1028,10 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             context = BarrierTaskContext.get()
 
             dev_ordinal = None
-            use_qdm = _can_use_qdm(booster_params.get("tree_method", None))
+            use_qdm = _can_use_qdm(
+                booster_params.get("tree_method", None),
+                booster_params.get("device", None),
+            )
             verbosity = booster_params.get("verbosity", 1)
             msg = "Training on CPUs"
             if run_on_gpu:
diff --git a/tests/python-sycl/test_sycl_with_sklearn.py b/tests/python-sycl/test_sycl_with_sklearn.py
new file mode 100644
index 000000000..8e75e77f8
--- /dev/null
+++ b/tests/python-sycl/test_sycl_with_sklearn.py
@@ -0,0 +1,37 @@
+import xgboost as xgb
+import pytest
+import sys
+import numpy as np
+
+from xgboost import testing as tm
+
+sys.path.append("tests/python")
+import test_with_sklearn as twskl  # noqa
+
+pytestmark = pytest.mark.skipif(**tm.no_sklearn())
+
+rng = np.random.RandomState(1994)
+
+
+def test_sycl_binary_classification():
+    from sklearn.datasets import load_digits
+    from sklearn.model_selection import KFold
+
+    digits = load_digits(n_class=2)
+    y = digits["target"]
+    X = digits["data"]
+    kf = KFold(n_splits=2, shuffle=True, random_state=rng)
+    for cls in (xgb.XGBClassifier, xgb.XGBRFClassifier):
+        for train_index, test_index in kf.split(X, y):
+            xgb_model = cls(random_state=42, device="sycl", n_estimators=4).fit(
+                X[train_index], y[train_index]
+            )
+            preds = xgb_model.predict(X[test_index])
+            labels = y[test_index]
+            err = sum(
+                1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i]
+            ) / float(len(preds))
+            print(preds)
+            print(labels)
+            print(err)
+            assert err < 0.1
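
Illustrative note, not part of the patch: the effect of passing `device` into `_can_use_qdm` is that sklearn-style estimators configured for a SYCL device skip `QuantileDMatrix` and build a regular `DMatrix` during `fit`, which the SYCL backend can consume. A minimal usage sketch, assuming an XGBoost build with the SYCL plugin enabled and a SYCL device available:

    # Sketch only; assumes an XGBoost build with SYCL support.
    import numpy as np
    import xgboost as xgb

    X = np.random.rand(256, 16)
    y = (X[:, 0] > 0.5).astype(int)

    # With device="sycl", _can_use_qdm() returns False, so fit() constructs a
    # plain DMatrix instead of a QuantileDMatrix before training.
    clf = xgb.XGBClassifier(device="sycl", tree_method="hist", n_estimators=4)
    clf.fit(X, y)
    print(clf.predict(X[:5]))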