[SYCL] Fix for sycl support with sklearn estimators (#10806)

---------

Co-authored-by: Dmitry Razdoburdin <>
This commit is contained in:
Dmitry Razdoburdin
2024-09-09 08:14:07 +02:00
committed by GitHub
parent 5f7f31d464
commit bba6aa74fb
4 changed files with 50 additions and 5 deletions

View File

@@ -1568,6 +1568,7 @@ def inplace_predict( # pylint: disable=unused-argument
async def _async_wrap_evaluation_matrices(
client: Optional["distributed.Client"],
device: Optional[str],
tree_method: Optional[str],
max_bin: Optional[int],
**kwargs: Any,
@@ -1575,7 +1576,7 @@ async def _async_wrap_evaluation_matrices(
"""A switch function for async environment."""
def _dispatch(ref: Optional[DaskDMatrix], **kwargs: Any) -> DaskDMatrix:
if _can_use_qdm(tree_method):
if _can_use_qdm(tree_method, device):
return DaskQuantileDMatrix(
client=client, ref=ref, max_bin=max_bin, **kwargs
)
@@ -1776,6 +1777,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
params = self.get_xgb_params()
dtrain, evals = await _async_wrap_evaluation_matrices(
client=self.client,
device=self.device,
tree_method=self.tree_method,
max_bin=self.max_bin,
X=X,
@@ -1865,6 +1867,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
params = self.get_xgb_params()
dtrain, evals = await _async_wrap_evaluation_matrices(
self.client,
device=self.device,
tree_method=self.tree_method,
max_bin=self.max_bin,
X=X,
@@ -2067,6 +2070,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
params = self.get_xgb_params()
dtrain, evals = await _async_wrap_evaluation_matrices(
self.client,
device=self.device,
tree_method=self.tree_method,
max_bin=self.max_bin,
X=X,

View File

@@ -67,8 +67,9 @@ def _check_rf_callback(
)
def _can_use_qdm(tree_method: Optional[str]) -> bool:
return tree_method in ("hist", "gpu_hist", None, "auto")
def _can_use_qdm(tree_method: Optional[str], device: Optional[str]) -> bool:
not_sycl = (device is None) or (not device.startswith("sycl"))
return tree_method in ("hist", "gpu_hist", None, "auto") and not_sycl
class _SklObjWProto(Protocol): # pylint: disable=too-few-public-methods
@@ -1031,7 +1032,7 @@ class XGBModel(XGBModelBase):
def _create_dmatrix(self, ref: Optional[DMatrix], **kwargs: Any) -> DMatrix:
# Use `QuantileDMatrix` to save memory.
if _can_use_qdm(self.tree_method) and self.booster != "gblinear":
if _can_use_qdm(self.tree_method, self.device) and self.booster != "gblinear":
try:
return QuantileDMatrix(
**kwargs, ref=ref, nthread=self.n_jobs, max_bin=self.max_bin

View File

@@ -1028,7 +1028,10 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
context = BarrierTaskContext.get()
dev_ordinal = None
use_qdm = _can_use_qdm(booster_params.get("tree_method", None))
use_qdm = _can_use_qdm(
booster_params.get("tree_method", None),
booster_params.get("device", None),
)
verbosity = booster_params.get("verbosity", 1)
msg = "Training on CPUs"
if run_on_gpu: