[SYCL] Fix for sycl support with sklearn estimators (#10806)

---------

Co-authored-by: Dmitry Razdoburdin <>
Dmitry Razdoburdin authored 2024-09-09 08:14:07 +02:00, committed by GitHub
parent 5f7f31d464
commit bba6aa74fb
4 changed files with 50 additions and 5 deletions
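
The core of the change is the new `device` parameter on `_can_use_qdm`: `QuantileDMatrix` construction has to be skipped for SYCL devices, so the predicate now takes the device string into account. A minimal sketch of the new behavior; the function body mirrors the diff below, and the example calls are illustrative:

from typing import Optional

def _can_use_qdm(tree_method: Optional[str], device: Optional[str]) -> bool:
    # QuantileDMatrix is only usable with hist-style tree methods on non-SYCL devices.
    not_sycl = (device is None) or (not device.startswith("sycl"))
    return tree_method in ("hist", "gpu_hist", None, "auto") and not_sycl

assert _can_use_qdm("hist", None)            # CPU hist: QuantileDMatrix is fine
assert _can_use_qdm("hist", "cuda")          # CUDA path is unchanged
assert not _can_use_qdm("hist", "sycl")      # SYCL: fall back to DMatrix
assert not _can_use_qdm("hist", "sycl:gpu")  # backend/ordinal suffixes as well
assert not _can_use_qdm("exact", None)       # exact never used QuantileDMatrix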

View File

@@ -1568,6 +1568,7 @@ def inplace_predict(  # pylint: disable=unused-argument

 async def _async_wrap_evaluation_matrices(
     client: Optional["distributed.Client"],
+    device: Optional[str],
     tree_method: Optional[str],
     max_bin: Optional[int],
     **kwargs: Any,
@@ -1575,7 +1576,7 @@ async def _async_wrap_evaluation_matrices(
     """A switch function for async environment."""

     def _dispatch(ref: Optional[DaskDMatrix], **kwargs: Any) -> DaskDMatrix:
-        if _can_use_qdm(tree_method):
+        if _can_use_qdm(tree_method, device):
             return DaskQuantileDMatrix(
                 client=client, ref=ref, max_bin=max_bin, **kwargs
             )
@@ -1776,6 +1777,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
         params = self.get_xgb_params()
         dtrain, evals = await _async_wrap_evaluation_matrices(
             client=self.client,
+            device=self.device,
             tree_method=self.tree_method,
             max_bin=self.max_bin,
             X=X,
@@ -1865,6 +1867,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
         params = self.get_xgb_params()
         dtrain, evals = await _async_wrap_evaluation_matrices(
             self.client,
+            device=self.device,
             tree_method=self.tree_method,
             max_bin=self.max_bin,
             X=X,
@@ -2067,6 +2070,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
         params = self.get_xgb_params()
         dtrain, evals = await _async_wrap_evaluation_matrices(
             self.client,
+            device=self.device,
             tree_method=self.tree_method,
             max_bin=self.max_bin,
             X=X,
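
Each Dask estimator now forwards `self.device`, so with `device="sycl"` the `_dispatch` helper builds a plain `DaskDMatrix` instead of a `DaskQuantileDMatrix`. A minimal usage sketch, assuming a SYCL-enabled XGBoost build and a local Dask cluster (the data and cluster size are illustrative):

from dask import array as da
from distributed import Client, LocalCluster

import xgboost as xgb

with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
    # Synthetic regression data; any dask-backed array works here.
    X = da.random.random((1000, 10), chunks=(250, 10))
    y = da.random.random(1000, chunks=(250,))
    reg = xgb.dask.DaskXGBRegressor(device="sycl", tree_method="hist", n_estimators=4)
    reg.client = client
    reg.fit(X, y)  # training data is wrapped in DaskDMatrix, not DaskQuantileDMatrix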

View File

@@ -67,8 +67,9 @@ def _check_rf_callback(
     )


-def _can_use_qdm(tree_method: Optional[str]) -> bool:
-    return tree_method in ("hist", "gpu_hist", None, "auto")
+def _can_use_qdm(tree_method: Optional[str], device: Optional[str]) -> bool:
+    not_sycl = (device is None) or (not device.startswith("sycl"))
+    return tree_method in ("hist", "gpu_hist", None, "auto") and not_sycl


 class _SklObjWProto(Protocol):  # pylint: disable=too-few-public-methods
@@ -1031,7 +1032,7 @@ class XGBModel(XGBModelBase):

     def _create_dmatrix(self, ref: Optional[DMatrix], **kwargs: Any) -> DMatrix:
         # Use `QuantileDMatrix` to save memory.
-        if _can_use_qdm(self.tree_method) and self.booster != "gblinear":
+        if _can_use_qdm(self.tree_method, self.device) and self.booster != "gblinear":
             try:
                 return QuantileDMatrix(
                     **kwargs, ref=ref, nthread=self.n_jobs, max_bin=self.max_bin

View File

@@ -1028,7 +1028,10 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):

             context = BarrierTaskContext.get()
             dev_ordinal = None
-            use_qdm = _can_use_qdm(booster_params.get("tree_method", None))
+            use_qdm = _can_use_qdm(
+                booster_params.get("tree_method", None),
+                booster_params.get("device", None),
+            )
             verbosity = booster_params.get("verbosity", 1)
             msg = "Training on CPUs"
             if run_on_gpu:

View File

@@ -0,0 +1,37 @@
+import xgboost as xgb
+import pytest
+import sys
+import numpy as np
+
+from xgboost import testing as tm
+
+sys.path.append("tests/python")
+import test_with_sklearn as twskl  # noqa
+
+pytestmark = pytest.mark.skipif(**tm.no_sklearn())
+
+rng = np.random.RandomState(1994)
+
+
+def test_sycl_binary_classification():
+    from sklearn.datasets import load_digits
+    from sklearn.model_selection import KFold
+
+    digits = load_digits(n_class=2)
+    y = digits["target"]
+    X = digits["data"]
+    kf = KFold(n_splits=2, shuffle=True, random_state=rng)
+    for cls in (xgb.XGBClassifier, xgb.XGBRFClassifier):
+        for train_index, test_index in kf.split(X, y):
+            xgb_model = cls(random_state=42, device="sycl", n_estimators=4).fit(
+                X[train_index], y[train_index]
+            )
+            preds = xgb_model.predict(X[test_index])
+            labels = y[test_index]
+            err = sum(
+                1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i]
+            ) / float(len(preds))
+            print(preds)
+            print(labels)
+            print(err)
+            assert err < 0.1
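
A condensed version of the new test, runnable outside pytest and assuming a SYCL-enabled build; before this fix, fitting a sklearn estimator with `device="sycl"` routed training data through `QuantileDMatrix`:

import numpy as np
import xgboost as xgb
from sklearn.datasets import load_digits

# With the fix, _create_dmatrix sees device="sycl" and builds a plain DMatrix.
X, y = load_digits(n_class=2, return_X_y=True)
clf = xgb.XGBClassifier(device="sycl", n_estimators=4, random_state=42)
clf.fit(X, y)
print("training error:", float(np.mean(clf.predict(X) != y)))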