From 87ab1ad607f9c15e643a089b0832afeebf0cc7d2 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 2 Feb 2021 08:45:52 +0800 Subject: [PATCH] [dask] Accept `Future` of model for prediction. (#6650) This PR changes predict and inplace_predict to accept a Future of model, to avoid sending models to workers repeatably. * Document is updated to reflect functionality additions in recent changes. --- doc/tutorials/dask.rst | 105 +++++++++++++++++++++++------- python-package/xgboost/dask.py | 79 ++++++++++++---------- python-package/xgboost/sklearn.py | 2 +- tests/python/test_with_dask.py | 42 ++++++------ 4 files changed, 150 insertions(+), 78 deletions(-) diff --git a/doc/tutorials/dask.rst b/doc/tutorials/dask.rst index d98a19aed..7530ad953 100644 --- a/doc/tutorials/dask.rst +++ b/doc/tutorials/dask.rst @@ -112,8 +112,11 @@ is a ``DaskDMatrix`` or ``da.Array``. When putting dask collection directly int ``predict`` function or using ``inplace_predict``, the output type depends on input data. See next section for details. -Alternatively, XGBoost also implements the Scikit-Learn interface with ``DaskXGBClassifier`` -and ``DaskXGBRegressor``. See ``xgboost/demo/dask`` for more examples. +Alternatively, XGBoost also implements the Scikit-Learn interface with +``DaskXGBClassifier``, ``DaskXGBRegressor``, ``DaskXGBRanker`` and 2 random forest +variances. This wrapper is similar to the single node Scikit-Learn interface in xgboost, +with dask collection as inputs and has an additional ``client`` attribute. See +``xgboost/demo/dask`` for more examples. ****************** @@ -160,6 +163,32 @@ if not using GPU, the number of threads used for prediction on each block matter now, xgboost uses single thread for each partition. If the number of blocks on each workers is smaller than number of cores, then the CPU workers might not be fully utilized. +One simple optimization for running consecutive predictions is using +``distributed.Future``: + +.. code-block:: python + + dataset = [X_0, X_1, X_2] + booster_f = client.scatter(booster, broadcast=True) + futures = [] + for X in dataset: + # Here we pass in a future instead of concrete booster + shap_f = xgb.dask.predict(client, booster_f, X, pred_contribs=True) + futures.append(shap_f) + + results = client.gather(futures) + + +This is only available on functional interface, as the Scikit-Learn wrapper doesn't know +how to maintain a valid future for booster. To obtain the booster object from +Scikit-Learn wrapper object: + +.. code-block:: python + + cls = xgb.dask.DaskXGBClassifier() + cls.fit(X, y) + + booster = cls.get_booster() *************************** @@ -231,17 +260,17 @@ will override the configuration in Dask. For example: with dask.distributed.LocalCluster(n_workers=7, threads_per_worker=4) as cluster: There are 4 threads allocated for each dask worker. Then by default XGBoost will use 4 -threads in each process for both training and prediction. But if ``nthread`` parameter is -set: +threads in each process for training. But if ``nthread`` parameter is set: .. code-block:: python - output = xgb.dask.train(client, - {'verbosity': 1, - 'nthread': 8, - 'tree_method': 'hist'}, - dtrain, - num_boost_round=4, evals=[(dtrain, 'train')]) + output = xgb.dask.train( + client, + {"verbosity": 1, "nthread": 8, "tree_method": "hist"}, + dtrain, + num_boost_round=4, + evals=[(dtrain, "train")], + ) XGBoost will use 8 threads in each training process. @@ -274,12 +303,12 @@ Functional interface: with_X = await xgb.dask.predict(client, output, X) inplace = await xgb.dask.inplace_predict(client, output, X) - # Use `client.compute` instead of the `compute` method from dask collection + # Use ``client.compute`` instead of the ``compute`` method from dask collection print(await client.compute(with_m)) While for the Scikit-Learn interface, trivial methods like ``set_params`` and accessing class -attributes like ``evals_result_`` do not require ``await``. Other methods involving +attributes like ``evals_result()`` do not require ``await``. Other methods involving actual computation will return a coroutine and hence require awaiting: .. code-block:: python @@ -373,6 +402,46 @@ If early stopping is enabled by also passing ``early_stopping_rounds``, you can print(booster.best_iteration) best_model = booster[: booster.best_iteration] + +******************* +Other customization +******************* + +XGBoost dask interface accepts other advanced features found in single node Python +interface, including callback functions, custom evaluation metric and objective: + + def eval_error_metric(predt, dtrain: xgb.DMatrix): + label = dtrain.get_label() + r = np.zeros(predt.shape) + gt = predt > 0.5 + r[gt] = 1 - label[gt] + le = predt <= 0.5 + r[le] = label[le] + return 'CustomErr', np.sum(r) + + # custom callback + early_stop = xgb.callback.EarlyStopping( + rounds=early_stopping_rounds, + metric_name="CustomErr", + data_name="Train", + save_best=True, + ) + + booster = xgb.dask.train( + client, + params={ + "objective": "binary:logistic", + "eval_metric": ["error", "rmse"], + "tree_method": "hist", + }, + dtrain=D_train, + evals=[(D_train, "Train"), (D_valid, "Valid")], + feval=eval_error_metric, # custom evaluation metric + num_boost_round=100, + callbacks=[early_stop], + ) + + ***************************************************************************** Why is the initialization of ``DaskDMatrix`` so slow and throws weird errors ***************************************************************************** @@ -414,15 +483,3 @@ References: #. https://github.com/dask/dask/issues/6833 #. https://stackoverflow.com/questions/45941528/how-to-efficiently-send-a-large-numpy-array-to-the-cluster-with-dask-array - -*********** -Limitations -*********** - -Basic functionality including model training and generating classification and regression predictions -have been implemented. However, there are still some other limitations we haven't -addressed yet: - -- Label encoding for the ``DaskXGBClassifier`` classifier may not be supported. So users need - to encode their training labels into discrete values first. -- Ranking is not yet supported. diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 5a9c334e8..1db498045 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -940,16 +940,14 @@ def _can_output_df(data: _DaskCollection, output_shape: Tuple) -> bool: async def _direct_predict_impl( - client: "distributed.Client", mapped_predict: Callable, - booster: Booster, + booster: "distributed.Future", data: _DaskCollection, base_margin: Optional[_DaskCollection], output_shape: Tuple[int, ...], meta: Dict[int, str], ) -> _DaskCollection: columns = list(meta.keys()) - booster_f = await client.scatter(data=booster, broadcast=True) if _can_output_df(data, output_shape): if base_margin is not None and isinstance(base_margin, da.Array): base_margin_df: Optional[dd.DataFrame] = base_margin.to_dask_dataframe() @@ -957,7 +955,7 @@ async def _direct_predict_impl( base_margin_df = base_margin predictions = dd.map_partitions( mapped_predict, - booster_f, + booster, data, True, columns, @@ -984,7 +982,7 @@ async def _direct_predict_impl( new_axis = [i + 2 for i in range(len(output_shape) - 2)] predictions = da.map_blocks( mapped_predict, - booster_f, + booster, data, False, columns, @@ -997,7 +995,10 @@ async def _direct_predict_impl( def _infer_predict_output( - booster: Booster, data: _DaskCollection, inplace: bool, **kwargs: Any + booster: Booster, + data: Union[DaskDMatrix, _DaskCollection], + inplace: bool, + **kwargs: Any ) -> Tuple[Tuple[int, ...], Dict[int, str]]: """Create a dummy test sample to infer output shape for prediction.""" if isinstance(data, DaskDMatrix): @@ -1021,11 +1022,29 @@ def _infer_predict_output( return test_predt.shape, meta +async def _get_model_future( + client: "distributed.Client", model: Union[Booster, Dict, "distributed.Future"] +) -> "distributed.Future": + if isinstance(model, Booster): + booster = await client.scatter(model, broadcast=True) + elif isinstance(model, dict): + booster = await client.scatter(model["booster"]) + elif isinstance(model, distributed.Future): + booster = model + if booster.type is not Booster: + raise TypeError( + f"Underlying type of model future should be `Booster`, got {booster.type}" + ) + else: + raise TypeError(_expect([Booster, dict, distributed.Future], type(model))) + return booster + + # pylint: disable=too-many-statements async def _predict_async( client: "distributed.Client", global_config: Dict[str, Any], - model: Union[Booster, Dict], + model: Union[Booster, Dict, "distributed.Future"], data: _DaskCollection, output_margin: bool, missing: float, @@ -1035,12 +1054,7 @@ async def _predict_async( pred_interactions: bool, validate_features: bool, ) -> _DaskCollection: - if isinstance(model, Booster): - _booster = model - elif isinstance(model, dict): - _booster = model["booster"] - else: - raise TypeError(_expect([Booster, dict], type(model))) + _booster = await _get_model_future(client, model) if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)): raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame], type(data))) @@ -1070,7 +1084,7 @@ async def _predict_async( # Predict on dask collection directly. if isinstance(data, (da.Array, dd.DataFrame)): _output_shape, meta = _infer_predict_output( - _booster, + await _booster.result(), data, inplace=False, output_margin=output_margin, @@ -1081,10 +1095,11 @@ async def _predict_async( validate_features=False, ) return await _direct_predict_impl( - client, mapped_predict, _booster, data, None, _output_shape, meta + mapped_predict, _booster, data, None, _output_shape, meta ) + output_shape, _ = _infer_predict_output( - booster=_booster, + booster=await _booster.result(), data=data, inplace=False, output_margin=output_margin, @@ -1108,11 +1123,9 @@ async def _predict_async( for i, blob in enumerate(part[1:]): if meta_names[i] == "base_margin": base_margin = blob - worker = distributed.get_worker() with config.config_context(**global_config): m = DMatrix( data, - nthread=worker.nthreads, missing=missing, base_margin=base_margin, feature_names=feature_names, @@ -1148,9 +1161,8 @@ async def _predict_async( all_shapes = [shape for part, shape, order in parts_with_order] futures = [] - booster_f = await client.scatter(data=_booster, broadcast=True) for part in all_parts: - f = client.submit(dispatched_predict, booster_f, part) + f = client.submit(dispatched_predict, _booster, part) futures.append(f) # Constructing a dask array from list of numpy arrays @@ -1168,7 +1180,7 @@ async def _predict_async( def predict( # pylint: disable=unused-argument client: "distributed.Client", - model: Union[TrainReturnT, Booster], + model: Union[TrainReturnT, Booster, "distributed.Future"], data: Union[DaskDMatrix, _DaskCollection], output_margin: bool = False, missing: float = numpy.nan, @@ -1194,7 +1206,8 @@ def predict( # pylint: disable=unused-argument Specify the dask client used for training. Use default client returned from dask if it's set to None. model: - The trained model. + The trained model. It can be a distributed.Future so user can + pre-scatter it onto all workers. data: Input data used for prediction. When input is a dataframe object, prediction output is a series. @@ -1221,19 +1234,14 @@ def predict( # pylint: disable=unused-argument async def _inplace_predict_async( client: "distributed.Client", global_config: Dict[str, Any], - model: Union[Booster, Dict], + model: Union[Booster, Dict, "distributed.Future"], data: _DaskCollection, iteration_range: Tuple[int, int] = (0, 0), predict_type: str = 'value', missing: float = numpy.nan ) -> _DaskCollection: client = _xgb_get_client(client) - if isinstance(model, Booster): - booster = model - elif isinstance(model, dict): - booster = model['booster'] - else: - raise TypeError(_expect([Booster, dict], type(model))) + booster = await _get_model_future(client, model) if not isinstance(data, (da.Array, dd.DataFrame)): raise TypeError(_expect([da.Array, dd.DataFrame], type(data))) @@ -1261,16 +1269,20 @@ async def _inplace_predict_async( return prediction shape, meta = _infer_predict_output( - booster, data, True, predict_type=predict_type, iteration_range=iteration_range + await booster.result(), + data, + True, + predict_type=predict_type, + iteration_range=iteration_range ) return await _direct_predict_impl( - client, mapped_predict, booster, data, None, shape, meta + mapped_predict, booster, data, None, shape, meta ) def inplace_predict( # pylint: disable=unused-argument client: "distributed.Client", - model: Union[TrainReturnT, Booster], + model: Union[TrainReturnT, Booster, "distributed.Future"], data: _DaskCollection, iteration_range: Tuple[int, int] = (0, 0), predict_type: str = 'value', @@ -1286,7 +1298,8 @@ def inplace_predict( # pylint: disable=unused-argument Specify the dask client used for training. Use default client returned from dask if it's set to None. model: - The trained model. + The trained model. It can be a distributed.Future so user can + pre-scatter it onto all workers. iteration_range: Specify the range of trees used for prediction. predict_type: diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 398fd63ef..3796b43ea 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -535,7 +535,7 @@ class XGBModel(XGBModelBase): json.dumps({k: v}) meta[k] = v except TypeError: - warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.') + warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.', UserWarning) meta['_estimator_type'] = self._get_type() meta_str = json.dumps(meta) self.get_booster().set_attr(scikit_learn=meta_str) diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index 43b0c33b5..e5481bf3b 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -608,28 +608,30 @@ def test_with_asyncio() -> None: asyncio.run(run_dask_classifier_asyncio(address)) -def test_predict() -> None: - with LocalCluster(n_workers=kWorkers) as cluster: - with Client(cluster) as client: - X, y, _ = generate_array() - dtrain = DaskDMatrix(client, X, y) - booster = xgb.dask.train( - client, {}, dtrain, num_boost_round=2)['booster'] +def test_predict(client: "Client") -> None: + X, y, _ = generate_array() + dtrain = DaskDMatrix(client, X, y) + booster = xgb.dask.train(client, {}, dtrain, num_boost_round=2)["booster"] - pred = xgb.dask.predict(client, model=booster, data=dtrain) - assert pred.ndim == 1 - assert pred.shape[0] == kRows + predt_0 = xgb.dask.predict(client, model=booster, data=dtrain) + assert predt_0.ndim == 1 + assert predt_0.shape[0] == kRows - margin = xgb.dask.predict(client, model=booster, data=dtrain, - output_margin=True) - assert margin.ndim == 1 - assert margin.shape[0] == kRows + margin = xgb.dask.predict(client, model=booster, data=dtrain, output_margin=True) + assert margin.ndim == 1 + assert margin.shape[0] == kRows - shap = xgb.dask.predict(client, model=booster, data=dtrain, - pred_contribs=True) - assert shap.ndim == 2 - assert shap.shape[0] == kRows - assert shap.shape[1] == kCols + 1 + shap = xgb.dask.predict(client, model=booster, data=dtrain, pred_contribs=True) + assert shap.ndim == 2 + assert shap.shape[0] == kRows + assert shap.shape[1] == kCols + 1 + + booster_f = client.scatter(booster, broadcast=True) + + predt_1 = xgb.dask.predict(client, booster_f, X).compute() + predt_2 = xgb.dask.inplace_predict(client, booster_f, X).compute() + np.testing.assert_allclose(predt_0, predt_1) + np.testing.assert_allclose(predt_0, predt_2) def test_predict_with_meta(client: "Client") -> None: @@ -1034,7 +1036,7 @@ class TestWithDask: rows = X.shape[0] cols = X.shape[1] - def assert_shape(shape): + def assert_shape(shape: Tuple[int, ...]) -> None: assert shape[0] == rows if "num_class" in params.keys(): assert shape[1] == params["num_class"]