diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py
index 4bdeb49e5..d571971a6 100644
--- a/python-package/xgboost/dask.py
+++ b/python-package/xgboost/dask.py
@@ -21,6 +21,7 @@ from contextlib import contextmanager
 from collections import defaultdict
 from collections.abc import Sequence
 from threading import Thread
+from functools import partial, update_wrapper
 from typing import TYPE_CHECKING, List, Tuple, Callable, Optional, Any, Union, Dict, Set
 from typing import Awaitable, Generator, TypeVar
 
@@ -967,7 +968,7 @@ def _can_output_df(is_df: bool, output_shape: Tuple) -> bool:
     return is_df and len(output_shape) <= 2
 
 
-async def _direct_predict_impl(
+async def _direct_predict_impl(  # pylint: disable=too-many-branches
     mapped_predict: Callable,
     booster: "distributed.Future",
     data: _DaskCollection,
@@ -1022,6 +1023,14 @@ async def _direct_predict_impl(
             new_axis = list(range(len(output_shape) - 2))
         else:
             new_axis = [i + 2 for i in range(len(output_shape) - 2)]
+        if len(output_shape) == 2:
+            # Somehow dask fails to infer the output shape change for 2-dim
+            # prediction, and `chunks = (None, output_shape[1])` doesn't work since
+            # None is not supported in map_blocks.
+            chunks = list(data.chunks)
+            chunks[1] = (output_shape[1], )
+        else:
+            chunks = None
         predictions = da.map_blocks(
             mapped_predict,
             booster,
@@ -1029,6 +1038,8 @@ async def _direct_predict_impl(
             False,
             columns,
             base_margin_array,
+
+            chunks=chunks,
             drop_axis=drop_axis,
             new_axis=new_axis,
             dtype=numpy.float32,
@@ -1776,20 +1787,20 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
         self,
         X: _DaskCollection,
         validate_features: bool,
-        output_margin: bool,
         base_margin: Optional[_DaskCollection],
         iteration_range: Optional[Tuple[int, int]],
     ) -> _DaskCollection:
-        if iteration_range is None:
-            iteration_range = (0, 0)
         predts = await super()._predict_async(
             data=X,
-            output_margin=output_margin,
+            output_margin=self.objective == "multi:softmax",
             validate_features=validate_features,
             base_margin=base_margin,
             iteration_range=iteration_range,
         )
-        return _cls_predict_proba(self.objective, predts, da.vstack)
+        vstack = update_wrapper(
+            partial(da.vstack, allow_unknown_chunksizes=True), da.vstack
+        )
+        return _cls_predict_proba(getattr(self, "n_classes_", None), predts, vstack)
 
     # pylint: disable=missing-function-docstring
     def predict_proba(
@@ -1797,7 +1808,6 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
         X: _DaskCollection,
         ntree_limit: Optional[int] = None,
         validate_features: bool = True,
-        output_margin: bool = False,
         base_margin: Optional[_DaskCollection] = None,
         iteration_range: Optional[Tuple[int, int]] = None,
     ) -> Any:
@@ -1808,7 +1818,6 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
             self._predict_proba_async,
             X=X,
             validate_features=validate_features,
-            output_margin=output_margin,
             base_margin=base_margin,
             iteration_range=iteration_range,
         )
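
Note on the `chunks` override above: `da.map_blocks` assumes the output chunking
matches the input unless told otherwise, which breaks when per-block prediction
changes the number of columns. A minimal standalone sketch of the same
situation, assuming made-up shapes and a stand-in `predict_block` helper:

import numpy as np
import dask.array as da

data = da.random.random((100, 4), chunks=(25, 4))
n_classes = 3  # stands in for output_shape[1]

def predict_block(block: np.ndarray) -> np.ndarray:
    # Stand-in for per-partition prediction: maps (rows, 4) -> (rows, n_classes).
    return np.ones((block.shape[0], n_classes), dtype=np.float32)

# Copy the input chunks and override the second dimension, mirroring the patch.
chunks = list(data.chunks)
chunks[1] = (n_classes,)
out = da.map_blocks(predict_block, data, chunks=chunks, dtype=np.float32)
assert out.compute().shape == (100, n_classes)
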
diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py
index c27281517..86f342381 100644
--- a/python-package/xgboost/sklearn.py
+++ b/python-package/xgboost/sklearn.py
@@ -4,7 +4,7 @@
 import copy
 import warnings
 import json
-from typing import Union, Optional, List, Dict, Callable, Tuple, Any
+from typing import Union, Optional, List, Dict, Callable, Tuple, Any, TypeVar
 import numpy as np
 from .core import Booster, DMatrix, XGBoostError
 from .core import _deprecate_positional_args, _convert_ntree_limit
@@ -561,6 +561,8 @@ class XGBModel(XGBModelBase):
         self._Booster.load_model(fname)
         meta = self._Booster.attr('scikit_learn')
         if meta is None:
+            # FIXME(jiaming): This doesn't have to be a problem, since most of the
+            # needed information, like num_class and objective, is in the Learner class.
             warnings.warn(
                 'Loading a native XGBoost model with Scikit-Learn interface.')
             return
@@ -571,6 +573,8 @@ class XGBModel(XGBModelBase):
                 self._le = XGBoostLabelEncoder()
                 self._le.from_json(v)
                 continue
+            # FIXME(jiaming): This can be removed once the label encoder is gone, since
+            # we can generate it from `np.arange(self.n_classes_)`.
             if k == 'classes_':
                 self.classes_ = np.array(v)
                 continue
@@ -1024,17 +1028,14 @@ class XGBModel(XGBModelBase):
         return np.array(json.loads(b.get_dump(dump_format='json')[0])['bias'])
 
 
-def _cls_predict_proba(
-    objective: Union[str, Callable], prediction: Any, vstack: Callable
-) -> Any:
-    if objective == 'multi:softmax':
-        raise ValueError('multi:softmax objective does not support predict_proba,'
-                         ' use `multi:softprob` or `binary:logistic` instead.')
-    if objective == 'multi:softprob' or callable(objective):
-        # Return prediction directly if if objective is defined by user since we don't
-        # know how to perform the transformation
+PredtT = TypeVar("PredtT")
+
+
+def _cls_predict_proba(n_classes: Optional[int], prediction: PredtT, vstack: Callable) -> PredtT:
+    assert len(prediction.shape) <= 2
+    if len(prediction.shape) == 2 and prediction.shape[1] == n_classes:
         return prediction
-    # Lastly the binary logistic function
+    # binary logistic function
     classone_probs = prediction
     classzero_probs = 1.0 - classone_probs
     return vstack((classzero_probs, classone_probs)).transpose()
@@ -1218,8 +1219,10 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
             return class_probs
 
         if len(class_probs.shape) > 1:
+            # turns softprob output into class indices, like multi:softmax
             column_indexes = np.argmax(class_probs, axis=1)
         else:
+            # turns binary probability into a class label
             column_indexes = np.repeat(0, class_probs.shape[0])
             column_indexes[class_probs > 0.5] = 1
 
@@ -1262,15 +1265,23 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
             a numpy array of shape array-like of shape (n_samples, n_classes) with the
             probability of each data example being of a given class.
         """
+        # custom obj: Do nothing, as we don't know what to do.
+        # softprob: Do nothing, the output is already a probability.
+        # softmax: Use the output margin to remove the argmax in PredTransform.
+        # binary:logistic: Expand the prob vector into a 2-class matrix after predict.
+        # binary:logitraw: Unsupported by predict_proba().
         class_probs = super().predict(
             X=X,
-            output_margin=False,
+            output_margin=self.objective == "multi:softmax",
             ntree_limit=ntree_limit,
             validate_features=validate_features,
             base_margin=base_margin,
             iteration_range=iteration_range
         )
-        return _cls_predict_proba(self.objective, class_probs, np.vstack)
+        # If the model is loaded from a raw booster, there's no `n_classes_`.
+        return _cls_predict_proba(
+            getattr(self, "n_classes_", None), class_probs, np.vstack
+        )
 
     def evals_result(self):
         """Return the evaluation results.
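
For the `binary:logistic` path, `_cls_predict_proba` receives a 1-dim vector of
positive-class probabilities and must expand it into the (n_samples, 2) matrix
that `predict_proba` is expected to return. A standalone sketch of that
expansion, with made-up values:

import numpy as np

prediction = np.array([0.1, 0.8, 0.4], dtype=np.float32)  # P(class 1) per row
classone_probs = prediction
classzero_probs = 1.0 - classone_probs
probas = np.vstack((classzero_probs, classone_probs)).transpose()
# probas[:, 0] is P(class 0), probas[:, 1] is P(class 1).
assert probas.shape == (3, 2)
assert np.allclose(probas.sum(axis=1), 1.0)
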
diff --git a/tests/python-gpu/test_gpu_with_dask.py b/tests/python-gpu/test_gpu_with_dask.py
index 9295cb46a..cfdd9db12 100644
--- a/tests/python-gpu/test_gpu_with_dask.py
+++ b/tests/python-gpu/test_gpu_with_dask.py
@@ -173,6 +173,7 @@ def run_gpu_hist(
     assert tm.non_increasing(history["train"][dataset.metric])
 
 
+@pytest.mark.skipif(**tm.no_cudf())
 def test_boost_from_prediction(local_cuda_cluster: LocalCUDACluster) -> None:
     from sklearn.datasets import load_breast_cancer
     with Client(local_cuda_cluster) as client:
@@ -201,6 +202,7 @@ class TestDistributedGPU:
     @settings(deadline=duration(seconds=120), suppress_health_check=suppress)
     @pytest.mark.skipif(**tm.no_dask())
     @pytest.mark.skipif(**tm.no_dask_cuda())
+    @pytest.mark.skipif(**tm.no_cupy())
     @pytest.mark.parametrize(
         "local_cuda_cluster", [{"n_workers": 2}], indirect=["local_cuda_cluster"]
     )
@@ -275,7 +277,7 @@ class TestDistributedGPU:
             X = dask_cudf.from_dask_dataframe(dd.from_dask_array(X_))
             y = dask_cudf.from_dask_dataframe(dd.from_dask_array(y_))
             w = dask_cudf.from_dask_dataframe(dd.from_dask_array(w_))
-            run_dask_classifier(X, y, w, model, client)
+            run_dask_classifier(X, y, w, model, client, 10)
 
     @pytest.mark.skipif(**tm.no_dask())
     @pytest.mark.skipif(**tm.no_dask_cuda())
@@ -453,6 +455,7 @@ async def run_from_dask_array_asyncio(scheduler_address: str) -> dxgb.TrainRetur
 
 @pytest.mark.skipif(**tm.no_dask())
 @pytest.mark.skipif(**tm.no_dask_cuda())
+@pytest.mark.skipif(**tm.no_cupy())
 @pytest.mark.mgpu
 def test_with_asyncio(local_cuda_cluster: LocalCUDACluster) -> None:
     with Client(local_cuda_cluster) as client:
diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py
index c08584af2..f8403cdf0 100644
--- a/tests/python/test_with_dask.py
+++ b/tests/python/test_with_dask.py
@@ -318,14 +318,17 @@ def run_dask_classifier(
     w: xgb.dask._DaskCollection,
     model: str,
     client: "Client",
+    n_classes: int,
 ) -> None:
+    metric = "merror" if n_classes > 2 else "logloss"
+
     if model == "boosting":
         classifier = xgb.dask.DaskXGBClassifier(
-            verbosity=1, n_estimators=2, eval_metric="merror"
+            verbosity=1, n_estimators=2, eval_metric=metric
         )
     else:
         classifier = xgb.dask.DaskXGBRFClassifier(
-            verbosity=1, n_estimators=2, eval_metric="merror"
+            verbosity=1, n_estimators=2, eval_metric=metric
         )
 
     assert classifier._estimator_type == "classifier"
@@ -343,7 +346,7 @@ def run_dask_classifier(
 
     assert isinstance(history, dict)
     assert list(history.keys())[0] == "validation_0"
-    assert list(history["validation_0"].keys())[0] == "merror"
+    assert list(history["validation_0"].keys())[0] == metric
     assert len(list(history["validation_0"])) == 1
     forest = int(
         json.loads(classifier.get_booster().save_config())["learner"][
@@ -351,34 +354,35 @@
         ]["gbtree_train_param"]["num_parallel_tree"]
     )
     if model == "boosting":
-        assert len(history["validation_0"]["merror"]) == 2
+        assert len(history["validation_0"][metric]) == 2
         assert forest == 1
     else:
-        assert len(history["validation_0"]["merror"]) == 1
+        assert len(history["validation_0"][metric]) == 1
         assert forest == 2
 
     # Test .predict_proba()
     probas = classifier.predict_proba(X).compute()
-    assert classifier.n_classes_ == 10
+    assert classifier.n_classes_ == n_classes
     assert probas.ndim == 2
    assert probas.shape[0] == kRows
-    assert probas.shape[1] == 10
+    assert probas.shape[1] == n_classes
 
-    cls_booster = classifier.get_booster()
-    single_node_proba = cls_booster.inplace_predict(X.compute())
+    if n_classes > 2:
+        cls_booster = classifier.get_booster()
+        single_node_proba = cls_booster.inplace_predict(X.compute())
 
-    # test shared by CPU and GPU
-    if isinstance(single_node_proba, np.ndarray):
-        np.testing.assert_allclose(single_node_proba, probas)
-    else:
-        import cupy
-        cupy.testing.assert_allclose(single_node_proba, probas)
+        # test shared by CPU and GPU
+        if isinstance(single_node_proba, np.ndarray):
+            np.testing.assert_allclose(single_node_proba, probas)
+        else:
+            import cupy
+            cupy.testing.assert_allclose(single_node_proba, probas)
 
     # Test with dataframe, not shared with GPU as cupy doesn't work well with da.unique.
-    if isinstance(X, da.Array):
+    if isinstance(X, da.Array) and n_classes > 2:
         X_d: dd.DataFrame = X.to_dask_dataframe()
-        assert classifier.n_classes_ == 10
+        assert classifier.n_classes_ == n_classes
         prediction_df = classifier.predict(X_d).compute()
 
         assert prediction_df.ndim == 1
@@ -393,7 +397,12 @@
 def test_dask_classifier(model: str, client: "Client") -> None:
     X, y, w = generate_array(with_weights=True)
     y = (y * 10).astype(np.int32)
-    run_dask_classifier(X, y, w, model, client)
+    run_dask_classifier(X, y, w, model, client, 10)
+
+    y_bin = y.copy()
+    y_bin[y > 5] = 1.0
+    y_bin[y <= 5] = 0.0
+    run_dask_classifier(X, y_bin, w, model, client, 2)
 
 
 @pytest.mark.skipif(**tm.no_sklearn())
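
End to end, the new binary path can be exercised as in this hypothetical usage
sketch (cluster size, data, and parameters are illustrative, mirroring the
assertions in `run_dask_classifier`):

import numpy as np
import dask.array as da
from distributed import Client, LocalCluster
import xgboost as xgb

if __name__ == "__main__":
    with Client(LocalCluster(n_workers=2)):
        X = da.random.random((256, 8), chunks=(64, 8))
        y = (da.random.random(256, chunks=64) > 0.5).astype(np.int32)
        clf = xgb.dask.DaskXGBClassifier(n_estimators=2, eval_metric="logloss")
        clf.fit(X, y)
        # binary:logistic predicts one column; predict_proba expands it to two.
        probas = clf.predict_proba(X).compute()
        assert clf.n_classes_ == 2
        assert probas.shape == (256, 2)
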