Handle the new device parameter in dask and demos. (#9386)

* Handle the new `device` parameter in dask and demos.

- Check no ordinal is specified in the dask interface.
- Update demos.
- Update dask doc.
- Update the condition for QDM.
This commit is contained in:
Jiaming Yuan
2023-07-15 19:11:20 +08:00
committed by GitHub
parent 9da5050643
commit 16eb41936d
31 changed files with 631 additions and 450 deletions

View File

@@ -1451,7 +1451,7 @@ class QuantileDMatrix(DMatrix):
enable_categorical: bool = False,
data_split_mode: DataSplitMode = DataSplitMode.ROW,
) -> None:
self.max_bin: int = max_bin if max_bin is not None else 256
self.max_bin = max_bin
self.missing = missing if missing is not None else np.nan
self.nthread = nthread if nthread is not None else -1
self._silent = silent # unused, kept for compatibility

View File

@@ -82,6 +82,7 @@ from .sklearn import (
XGBRanker,
XGBRankerMixIn,
XGBRegressorBase,
_can_use_qdm,
_check_rf_callback,
_cls_predict_proba,
_objective_decorator,
@@ -617,14 +618,7 @@ class DaskPartitionIter(DataIter): # pylint: disable=R0902
if self._iter == len(self._data):
# Return 0 when there's no more batch.
return 0
feature_names: Optional[FeatureNames] = None
if self._feature_names:
feature_names = self._feature_names
else:
if hasattr(self.data(), "columns"):
feature_names = self.data().columns.format()
else:
feature_names = None
input_data(
data=self.data(),
label=self._get("_label"),
@@ -634,7 +628,7 @@ class DaskPartitionIter(DataIter): # pylint: disable=R0902
base_margin=self._get("_base_margin"),
label_lower_bound=self._get("_label_lower_bound"),
label_upper_bound=self._get("_label_upper_bound"),
feature_names=feature_names,
feature_names=self._feature_names,
feature_types=self._feature_types,
feature_weights=self._feature_weights,
)
@@ -935,6 +929,12 @@ async def _train_async(
raise NotImplementedError(
f"booster `{params['booster']}` is not yet supported for dask."
)
device = params.get("device", None)
if device and device.find(":") != -1:
raise ValueError(
"The dask interface for XGBoost doesn't support selecting specific device"
" ordinal. Use `device=cpu` or `device=cuda` instead."
)
def dispatched_train(
parameters: Dict,
@@ -1574,7 +1574,7 @@ async def _async_wrap_evaluation_matrices(
"""A switch function for async environment."""
def _dispatch(ref: Optional[DaskDMatrix], **kwargs: Any) -> DaskDMatrix:
if tree_method in ("hist", "gpu_hist"):
if _can_use_qdm(tree_method):
return DaskQuantileDMatrix(
client=client, ref=ref, max_bin=max_bin, **kwargs
)

View File

@@ -76,6 +76,10 @@ def _check_rf_callback(
)
def _can_use_qdm(tree_method: Optional[str]) -> bool:
return tree_method in ("hist", "gpu_hist", None, "auto")
SklObjective = Optional[
Union[str, Callable[[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]]
]
@@ -939,7 +943,7 @@ class XGBModel(XGBModelBase):
def _create_dmatrix(self, ref: Optional[DMatrix], **kwargs: Any) -> DMatrix:
# Use `QuantileDMatrix` to save memory.
if self.tree_method in ("hist", "gpu_hist"):
if _can_use_qdm(self.tree_method) and self.booster != "gblinear":
try:
return QuantileDMatrix(
**kwargs, ref=ref, nthread=self.n_jobs, max_bin=self.max_bin

View File

@@ -61,7 +61,7 @@ import xgboost
from xgboost import XGBClassifier
from xgboost.compat import is_cudf_available
from xgboost.core import Booster
from xgboost.sklearn import DEFAULT_N_ESTIMATORS, XGBModel
from xgboost.sklearn import DEFAULT_N_ESTIMATORS, XGBModel, _can_use_qdm
from xgboost.training import train as worker_train
from .data import (
@@ -901,7 +901,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
context = BarrierTaskContext.get()
dev_ordinal = None
use_hist = booster_params.get("tree_method", None) in ("hist", "gpu_hist")
use_qdm = _can_use_qdm(booster_params.get("tree_method", None))
if use_gpu:
dev_ordinal = (
@@ -912,9 +912,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
# because without cuDF, DMatrix performs better than QDM.
# Note: Check `is_cudf_available` on the Spark worker side because the
# Spark worker might have a different Python environment from the driver side.
use_qdm = use_hist and is_cudf_available()
else:
use_qdm = use_hist
use_qdm = use_qdm and is_cudf_available()
if use_qdm and (booster_params.get("max_bin", None) is not None):
dmatrix_kwargs["max_bin"] = booster_params["max_bin"]