[SYCL] Fix for sycl support with sklearn estimators (#10806)
--------- Co-authored-by: Dmitry Razdoburdin <>
This commit is contained in:
parent
5f7f31d464
commit
bba6aa74fb
@ -1568,6 +1568,7 @@ def inplace_predict( # pylint: disable=unused-argument
|
|||||||
|
|
||||||
async def _async_wrap_evaluation_matrices(
|
async def _async_wrap_evaluation_matrices(
|
||||||
client: Optional["distributed.Client"],
|
client: Optional["distributed.Client"],
|
||||||
|
device: Optional[str],
|
||||||
tree_method: Optional[str],
|
tree_method: Optional[str],
|
||||||
max_bin: Optional[int],
|
max_bin: Optional[int],
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
@ -1575,7 +1576,7 @@ async def _async_wrap_evaluation_matrices(
|
|||||||
"""A switch function for async environment."""
|
"""A switch function for async environment."""
|
||||||
|
|
||||||
def _dispatch(ref: Optional[DaskDMatrix], **kwargs: Any) -> DaskDMatrix:
|
def _dispatch(ref: Optional[DaskDMatrix], **kwargs: Any) -> DaskDMatrix:
|
||||||
if _can_use_qdm(tree_method):
|
if _can_use_qdm(tree_method, device):
|
||||||
return DaskQuantileDMatrix(
|
return DaskQuantileDMatrix(
|
||||||
client=client, ref=ref, max_bin=max_bin, **kwargs
|
client=client, ref=ref, max_bin=max_bin, **kwargs
|
||||||
)
|
)
|
||||||
@ -1776,6 +1777,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
|||||||
params = self.get_xgb_params()
|
params = self.get_xgb_params()
|
||||||
dtrain, evals = await _async_wrap_evaluation_matrices(
|
dtrain, evals = await _async_wrap_evaluation_matrices(
|
||||||
client=self.client,
|
client=self.client,
|
||||||
|
device=self.device,
|
||||||
tree_method=self.tree_method,
|
tree_method=self.tree_method,
|
||||||
max_bin=self.max_bin,
|
max_bin=self.max_bin,
|
||||||
X=X,
|
X=X,
|
||||||
@ -1865,6 +1867,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
|||||||
params = self.get_xgb_params()
|
params = self.get_xgb_params()
|
||||||
dtrain, evals = await _async_wrap_evaluation_matrices(
|
dtrain, evals = await _async_wrap_evaluation_matrices(
|
||||||
self.client,
|
self.client,
|
||||||
|
device=self.device,
|
||||||
tree_method=self.tree_method,
|
tree_method=self.tree_method,
|
||||||
max_bin=self.max_bin,
|
max_bin=self.max_bin,
|
||||||
X=X,
|
X=X,
|
||||||
@ -2067,6 +2070,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
|
|||||||
params = self.get_xgb_params()
|
params = self.get_xgb_params()
|
||||||
dtrain, evals = await _async_wrap_evaluation_matrices(
|
dtrain, evals = await _async_wrap_evaluation_matrices(
|
||||||
self.client,
|
self.client,
|
||||||
|
device=self.device,
|
||||||
tree_method=self.tree_method,
|
tree_method=self.tree_method,
|
||||||
max_bin=self.max_bin,
|
max_bin=self.max_bin,
|
||||||
X=X,
|
X=X,
|
||||||
|
|||||||
@ -67,8 +67,9 @@ def _check_rf_callback(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _can_use_qdm(tree_method: Optional[str]) -> bool:
|
def _can_use_qdm(tree_method: Optional[str], device: Optional[str]) -> bool:
|
||||||
return tree_method in ("hist", "gpu_hist", None, "auto")
|
not_sycl = (device is None) or (not device.startswith("sycl"))
|
||||||
|
return tree_method in ("hist", "gpu_hist", None, "auto") and not_sycl
|
||||||
|
|
||||||
|
|
||||||
class _SklObjWProto(Protocol): # pylint: disable=too-few-public-methods
|
class _SklObjWProto(Protocol): # pylint: disable=too-few-public-methods
|
||||||
@ -1031,7 +1032,7 @@ class XGBModel(XGBModelBase):
|
|||||||
|
|
||||||
def _create_dmatrix(self, ref: Optional[DMatrix], **kwargs: Any) -> DMatrix:
|
def _create_dmatrix(self, ref: Optional[DMatrix], **kwargs: Any) -> DMatrix:
|
||||||
# Use `QuantileDMatrix` to save memory.
|
# Use `QuantileDMatrix` to save memory.
|
||||||
if _can_use_qdm(self.tree_method) and self.booster != "gblinear":
|
if _can_use_qdm(self.tree_method, self.device) and self.booster != "gblinear":
|
||||||
try:
|
try:
|
||||||
return QuantileDMatrix(
|
return QuantileDMatrix(
|
||||||
**kwargs, ref=ref, nthread=self.n_jobs, max_bin=self.max_bin
|
**kwargs, ref=ref, nthread=self.n_jobs, max_bin=self.max_bin
|
||||||
|
|||||||
@ -1028,7 +1028,10 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
|||||||
context = BarrierTaskContext.get()
|
context = BarrierTaskContext.get()
|
||||||
|
|
||||||
dev_ordinal = None
|
dev_ordinal = None
|
||||||
use_qdm = _can_use_qdm(booster_params.get("tree_method", None))
|
use_qdm = _can_use_qdm(
|
||||||
|
booster_params.get("tree_method", None),
|
||||||
|
booster_params.get("device", None),
|
||||||
|
)
|
||||||
verbosity = booster_params.get("verbosity", 1)
|
verbosity = booster_params.get("verbosity", 1)
|
||||||
msg = "Training on CPUs"
|
msg = "Training on CPUs"
|
||||||
if run_on_gpu:
|
if run_on_gpu:
|
||||||
|
|||||||
37
tests/python-sycl/test_sycl_with_sklearn.py
Normal file
37
tests/python-sycl/test_sycl_with_sklearn.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
import xgboost as xgb
|
||||||
|
import pytest
|
||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from xgboost import testing as tm
|
||||||
|
|
||||||
|
sys.path.append("tests/python")
|
||||||
|
import test_with_sklearn as twskl # noqa
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.skipif(**tm.no_sklearn())
|
||||||
|
|
||||||
|
rng = np.random.RandomState(1994)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sycl_binary_classification():
    """Fit binary classifiers with ``device="sycl"`` and check held-out error.

    Uses the two-class digits dataset with a shuffled 2-fold split.  For each
    of ``XGBClassifier`` and ``XGBRFClassifier``, every fold must reach a
    misclassification rate below 10% on its held-out half.
    """
    from sklearn.datasets import load_digits
    from sklearn.model_selection import KFold

    digits = load_digits(n_class=2)
    y = digits["target"]
    X = digits["data"]
    kf = KFold(n_splits=2, shuffle=True, random_state=rng)
    for cls in (xgb.XGBClassifier, xgb.XGBRFClassifier):
        for train_index, test_index in kf.split(X, y):
            xgb_model = cls(random_state=42, device="sycl", n_estimators=4).fit(
                X[train_index], y[train_index]
            )
            preds = xgb_model.predict(X[test_index])
            labels = y[test_index]
            # Fraction of held-out samples whose thresholded prediction
            # disagrees with the true label.
            err = float(np.mean((preds > 0.5).astype(int) != labels))
            # Report the failing estimator and rate in the assertion itself
            # instead of printing the raw arrays on every fold.
            assert err < 0.1, f"{cls.__name__}: error rate {err:.4f}"
|
||||||
Loading…
x
Reference in New Issue
Block a user