[dask] Accept Future of model for prediction. (#6650)

This PR changes predict and inplace_predict to accept a Future of model, to avoid sending models to workers repeatably.

* Document is updated to reflect functionality additions in recent changes.
This commit is contained in:
Jiaming Yuan 2021-02-02 08:45:52 +08:00 committed by GitHub
parent a9ec0ea6da
commit 87ab1ad607
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 150 additions and 78 deletions

View File

@ -112,8 +112,11 @@ is a ``DaskDMatrix`` or ``da.Array``. When putting dask collection directly int
``predict`` function or using ``inplace_predict``, the output type depends on input data. ``predict`` function or using ``inplace_predict``, the output type depends on input data.
See next section for details. See next section for details.
Alternatively, XGBoost also implements the Scikit-Learn interface with ``DaskXGBClassifier`` Alternatively, XGBoost also implements the Scikit-Learn interface with
and ``DaskXGBRegressor``. See ``xgboost/demo/dask`` for more examples. ``DaskXGBClassifier``, ``DaskXGBRegressor``, ``DaskXGBRanker`` and 2 random forest
variances. This wrapper is similar to the single node Scikit-Learn interface in xgboost,
with dask collection as inputs and has an additional ``client`` attribute. See
``xgboost/demo/dask`` for more examples.
****************** ******************
@ -160,6 +163,32 @@ if not using GPU, the number of threads used for prediction on each block matter
now, xgboost uses single thread for each partition. If the number of blocks on each now, xgboost uses single thread for each partition. If the number of blocks on each
workers is smaller than number of cores, then the CPU workers might not be fully utilized. workers is smaller than number of cores, then the CPU workers might not be fully utilized.
One simple optimization for running consecutive predictions is using
``distributed.Future``:
.. code-block:: python
dataset = [X_0, X_1, X_2]
booster_f = client.scatter(booster, broadcast=True)
futures = []
for X in dataset:
# Here we pass in a future instead of concrete booster
shap_f = xgb.dask.predict(client, booster_f, X, pred_contribs=True)
futures.append(shap_f)
results = client.gather(futures)
This is only available on functional interface, as the Scikit-Learn wrapper doesn't know
how to maintain a valid future for booster. To obtain the booster object from
Scikit-Learn wrapper object:
.. code-block:: python
cls = xgb.dask.DaskXGBClassifier()
cls.fit(X, y)
booster = cls.get_booster()
*************************** ***************************
@ -231,17 +260,17 @@ will override the configuration in Dask. For example:
with dask.distributed.LocalCluster(n_workers=7, threads_per_worker=4) as cluster: with dask.distributed.LocalCluster(n_workers=7, threads_per_worker=4) as cluster:
There are 4 threads allocated for each dask worker. Then by default XGBoost will use 4 There are 4 threads allocated for each dask worker. Then by default XGBoost will use 4
threads in each process for both training and prediction. But if ``nthread`` parameter is threads in each process for training. But if ``nthread`` parameter is set:
set:
.. code-block:: python .. code-block:: python
output = xgb.dask.train(client, output = xgb.dask.train(
{'verbosity': 1, client,
'nthread': 8, {"verbosity": 1, "nthread": 8, "tree_method": "hist"},
'tree_method': 'hist'}, dtrain,
dtrain, num_boost_round=4,
num_boost_round=4, evals=[(dtrain, 'train')]) evals=[(dtrain, "train")],
)
XGBoost will use 8 threads in each training process. XGBoost will use 8 threads in each training process.
@ -274,12 +303,12 @@ Functional interface:
with_X = await xgb.dask.predict(client, output, X) with_X = await xgb.dask.predict(client, output, X)
inplace = await xgb.dask.inplace_predict(client, output, X) inplace = await xgb.dask.inplace_predict(client, output, X)
# Use `client.compute` instead of the `compute` method from dask collection # Use ``client.compute`` instead of the ``compute`` method from dask collection
print(await client.compute(with_m)) print(await client.compute(with_m))
While for the Scikit-Learn interface, trivial methods like ``set_params`` and accessing class While for the Scikit-Learn interface, trivial methods like ``set_params`` and accessing class
attributes like ``evals_result_`` do not require ``await``. Other methods involving attributes like ``evals_result()`` do not require ``await``. Other methods involving
actual computation will return a coroutine and hence require awaiting: actual computation will return a coroutine and hence require awaiting:
.. code-block:: python .. code-block:: python
@ -373,6 +402,46 @@ If early stopping is enabled by also passing ``early_stopping_rounds``, you can
print(booster.best_iteration) print(booster.best_iteration)
best_model = booster[: booster.best_iteration] best_model = booster[: booster.best_iteration]
*******************
Other customization
*******************
XGBoost dask interface accepts other advanced features found in single node Python
interface, including callback functions, custom evaluation metric and objective:
def eval_error_metric(predt, dtrain: xgb.DMatrix):
label = dtrain.get_label()
r = np.zeros(predt.shape)
gt = predt > 0.5
r[gt] = 1 - label[gt]
le = predt <= 0.5
r[le] = label[le]
return 'CustomErr', np.sum(r)
# custom callback
early_stop = xgb.callback.EarlyStopping(
rounds=early_stopping_rounds,
metric_name="CustomErr",
data_name="Train",
save_best=True,
)
booster = xgb.dask.train(
client,
params={
"objective": "binary:logistic",
"eval_metric": ["error", "rmse"],
"tree_method": "hist",
},
dtrain=D_train,
evals=[(D_train, "Train"), (D_valid, "Valid")],
feval=eval_error_metric, # custom evaluation metric
num_boost_round=100,
callbacks=[early_stop],
)
***************************************************************************** *****************************************************************************
Why is the initialization of ``DaskDMatrix`` so slow and throws weird errors Why is the initialization of ``DaskDMatrix`` so slow and throws weird errors
***************************************************************************** *****************************************************************************
@ -414,15 +483,3 @@ References:
#. https://github.com/dask/dask/issues/6833 #. https://github.com/dask/dask/issues/6833
#. https://stackoverflow.com/questions/45941528/how-to-efficiently-send-a-large-numpy-array-to-the-cluster-with-dask-array #. https://stackoverflow.com/questions/45941528/how-to-efficiently-send-a-large-numpy-array-to-the-cluster-with-dask-array
***********
Limitations
***********
Basic functionality including model training and generating classification and regression predictions
have been implemented. However, there are still some other limitations we haven't
addressed yet:
- Label encoding for the ``DaskXGBClassifier`` classifier may not be supported. So users need
to encode their training labels into discrete values first.
- Ranking is not yet supported.

View File

@ -940,16 +940,14 @@ def _can_output_df(data: _DaskCollection, output_shape: Tuple) -> bool:
async def _direct_predict_impl( async def _direct_predict_impl(
client: "distributed.Client",
mapped_predict: Callable, mapped_predict: Callable,
booster: Booster, booster: "distributed.Future",
data: _DaskCollection, data: _DaskCollection,
base_margin: Optional[_DaskCollection], base_margin: Optional[_DaskCollection],
output_shape: Tuple[int, ...], output_shape: Tuple[int, ...],
meta: Dict[int, str], meta: Dict[int, str],
) -> _DaskCollection: ) -> _DaskCollection:
columns = list(meta.keys()) columns = list(meta.keys())
booster_f = await client.scatter(data=booster, broadcast=True)
if _can_output_df(data, output_shape): if _can_output_df(data, output_shape):
if base_margin is not None and isinstance(base_margin, da.Array): if base_margin is not None and isinstance(base_margin, da.Array):
base_margin_df: Optional[dd.DataFrame] = base_margin.to_dask_dataframe() base_margin_df: Optional[dd.DataFrame] = base_margin.to_dask_dataframe()
@ -957,7 +955,7 @@ async def _direct_predict_impl(
base_margin_df = base_margin base_margin_df = base_margin
predictions = dd.map_partitions( predictions = dd.map_partitions(
mapped_predict, mapped_predict,
booster_f, booster,
data, data,
True, True,
columns, columns,
@ -984,7 +982,7 @@ async def _direct_predict_impl(
new_axis = [i + 2 for i in range(len(output_shape) - 2)] new_axis = [i + 2 for i in range(len(output_shape) - 2)]
predictions = da.map_blocks( predictions = da.map_blocks(
mapped_predict, mapped_predict,
booster_f, booster,
data, data,
False, False,
columns, columns,
@ -997,7 +995,10 @@ async def _direct_predict_impl(
def _infer_predict_output( def _infer_predict_output(
booster: Booster, data: _DaskCollection, inplace: bool, **kwargs: Any booster: Booster,
data: Union[DaskDMatrix, _DaskCollection],
inplace: bool,
**kwargs: Any
) -> Tuple[Tuple[int, ...], Dict[int, str]]: ) -> Tuple[Tuple[int, ...], Dict[int, str]]:
"""Create a dummy test sample to infer output shape for prediction.""" """Create a dummy test sample to infer output shape for prediction."""
if isinstance(data, DaskDMatrix): if isinstance(data, DaskDMatrix):
@ -1021,11 +1022,29 @@ def _infer_predict_output(
return test_predt.shape, meta return test_predt.shape, meta
async def _get_model_future(
client: "distributed.Client", model: Union[Booster, Dict, "distributed.Future"]
) -> "distributed.Future":
if isinstance(model, Booster):
booster = await client.scatter(model, broadcast=True)
elif isinstance(model, dict):
booster = await client.scatter(model["booster"])
elif isinstance(model, distributed.Future):
booster = model
if booster.type is not Booster:
raise TypeError(
f"Underlying type of model future should be `Booster`, got {booster.type}"
)
else:
raise TypeError(_expect([Booster, dict, distributed.Future], type(model)))
return booster
# pylint: disable=too-many-statements # pylint: disable=too-many-statements
async def _predict_async( async def _predict_async(
client: "distributed.Client", client: "distributed.Client",
global_config: Dict[str, Any], global_config: Dict[str, Any],
model: Union[Booster, Dict], model: Union[Booster, Dict, "distributed.Future"],
data: _DaskCollection, data: _DaskCollection,
output_margin: bool, output_margin: bool,
missing: float, missing: float,
@ -1035,12 +1054,7 @@ async def _predict_async(
pred_interactions: bool, pred_interactions: bool,
validate_features: bool, validate_features: bool,
) -> _DaskCollection: ) -> _DaskCollection:
if isinstance(model, Booster): _booster = await _get_model_future(client, model)
_booster = model
elif isinstance(model, dict):
_booster = model["booster"]
else:
raise TypeError(_expect([Booster, dict], type(model)))
if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)): if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)):
raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame], type(data))) raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame], type(data)))
@ -1070,7 +1084,7 @@ async def _predict_async(
# Predict on dask collection directly. # Predict on dask collection directly.
if isinstance(data, (da.Array, dd.DataFrame)): if isinstance(data, (da.Array, dd.DataFrame)):
_output_shape, meta = _infer_predict_output( _output_shape, meta = _infer_predict_output(
_booster, await _booster.result(),
data, data,
inplace=False, inplace=False,
output_margin=output_margin, output_margin=output_margin,
@ -1081,10 +1095,11 @@ async def _predict_async(
validate_features=False, validate_features=False,
) )
return await _direct_predict_impl( return await _direct_predict_impl(
client, mapped_predict, _booster, data, None, _output_shape, meta mapped_predict, _booster, data, None, _output_shape, meta
) )
output_shape, _ = _infer_predict_output( output_shape, _ = _infer_predict_output(
booster=_booster, booster=await _booster.result(),
data=data, data=data,
inplace=False, inplace=False,
output_margin=output_margin, output_margin=output_margin,
@ -1108,11 +1123,9 @@ async def _predict_async(
for i, blob in enumerate(part[1:]): for i, blob in enumerate(part[1:]):
if meta_names[i] == "base_margin": if meta_names[i] == "base_margin":
base_margin = blob base_margin = blob
worker = distributed.get_worker()
with config.config_context(**global_config): with config.config_context(**global_config):
m = DMatrix( m = DMatrix(
data, data,
nthread=worker.nthreads,
missing=missing, missing=missing,
base_margin=base_margin, base_margin=base_margin,
feature_names=feature_names, feature_names=feature_names,
@ -1148,9 +1161,8 @@ async def _predict_async(
all_shapes = [shape for part, shape, order in parts_with_order] all_shapes = [shape for part, shape, order in parts_with_order]
futures = [] futures = []
booster_f = await client.scatter(data=_booster, broadcast=True)
for part in all_parts: for part in all_parts:
f = client.submit(dispatched_predict, booster_f, part) f = client.submit(dispatched_predict, _booster, part)
futures.append(f) futures.append(f)
# Constructing a dask array from list of numpy arrays # Constructing a dask array from list of numpy arrays
@ -1168,7 +1180,7 @@ async def _predict_async(
def predict( # pylint: disable=unused-argument def predict( # pylint: disable=unused-argument
client: "distributed.Client", client: "distributed.Client",
model: Union[TrainReturnT, Booster], model: Union[TrainReturnT, Booster, "distributed.Future"],
data: Union[DaskDMatrix, _DaskCollection], data: Union[DaskDMatrix, _DaskCollection],
output_margin: bool = False, output_margin: bool = False,
missing: float = numpy.nan, missing: float = numpy.nan,
@ -1194,7 +1206,8 @@ def predict( # pylint: disable=unused-argument
Specify the dask client used for training. Use default client Specify the dask client used for training. Use default client
returned from dask if it's set to None. returned from dask if it's set to None.
model: model:
The trained model. The trained model. It can be a distributed.Future so user can
pre-scatter it onto all workers.
data: data:
Input data used for prediction. When input is a dataframe object, Input data used for prediction. When input is a dataframe object,
prediction output is a series. prediction output is a series.
@ -1221,19 +1234,14 @@ def predict( # pylint: disable=unused-argument
async def _inplace_predict_async( async def _inplace_predict_async(
client: "distributed.Client", client: "distributed.Client",
global_config: Dict[str, Any], global_config: Dict[str, Any],
model: Union[Booster, Dict], model: Union[Booster, Dict, "distributed.Future"],
data: _DaskCollection, data: _DaskCollection,
iteration_range: Tuple[int, int] = (0, 0), iteration_range: Tuple[int, int] = (0, 0),
predict_type: str = 'value', predict_type: str = 'value',
missing: float = numpy.nan missing: float = numpy.nan
) -> _DaskCollection: ) -> _DaskCollection:
client = _xgb_get_client(client) client = _xgb_get_client(client)
if isinstance(model, Booster): booster = await _get_model_future(client, model)
booster = model
elif isinstance(model, dict):
booster = model['booster']
else:
raise TypeError(_expect([Booster, dict], type(model)))
if not isinstance(data, (da.Array, dd.DataFrame)): if not isinstance(data, (da.Array, dd.DataFrame)):
raise TypeError(_expect([da.Array, dd.DataFrame], type(data))) raise TypeError(_expect([da.Array, dd.DataFrame], type(data)))
@ -1261,16 +1269,20 @@ async def _inplace_predict_async(
return prediction return prediction
shape, meta = _infer_predict_output( shape, meta = _infer_predict_output(
booster, data, True, predict_type=predict_type, iteration_range=iteration_range await booster.result(),
data,
True,
predict_type=predict_type,
iteration_range=iteration_range
) )
return await _direct_predict_impl( return await _direct_predict_impl(
client, mapped_predict, booster, data, None, shape, meta mapped_predict, booster, data, None, shape, meta
) )
def inplace_predict( # pylint: disable=unused-argument def inplace_predict( # pylint: disable=unused-argument
client: "distributed.Client", client: "distributed.Client",
model: Union[TrainReturnT, Booster], model: Union[TrainReturnT, Booster, "distributed.Future"],
data: _DaskCollection, data: _DaskCollection,
iteration_range: Tuple[int, int] = (0, 0), iteration_range: Tuple[int, int] = (0, 0),
predict_type: str = 'value', predict_type: str = 'value',
@ -1286,7 +1298,8 @@ def inplace_predict( # pylint: disable=unused-argument
Specify the dask client used for training. Use default client Specify the dask client used for training. Use default client
returned from dask if it's set to None. returned from dask if it's set to None.
model: model:
The trained model. The trained model. It can be a distributed.Future so user can
pre-scatter it onto all workers.
iteration_range: iteration_range:
Specify the range of trees used for prediction. Specify the range of trees used for prediction.
predict_type: predict_type:

View File

@ -535,7 +535,7 @@ class XGBModel(XGBModelBase):
json.dumps({k: v}) json.dumps({k: v})
meta[k] = v meta[k] = v
except TypeError: except TypeError:
warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.') warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.', UserWarning)
meta['_estimator_type'] = self._get_type() meta['_estimator_type'] = self._get_type()
meta_str = json.dumps(meta) meta_str = json.dumps(meta)
self.get_booster().set_attr(scikit_learn=meta_str) self.get_booster().set_attr(scikit_learn=meta_str)

View File

@ -608,28 +608,30 @@ def test_with_asyncio() -> None:
asyncio.run(run_dask_classifier_asyncio(address)) asyncio.run(run_dask_classifier_asyncio(address))
def test_predict() -> None: def test_predict(client: "Client") -> None:
with LocalCluster(n_workers=kWorkers) as cluster: X, y, _ = generate_array()
with Client(cluster) as client: dtrain = DaskDMatrix(client, X, y)
X, y, _ = generate_array() booster = xgb.dask.train(client, {}, dtrain, num_boost_round=2)["booster"]
dtrain = DaskDMatrix(client, X, y)
booster = xgb.dask.train(
client, {}, dtrain, num_boost_round=2)['booster']
pred = xgb.dask.predict(client, model=booster, data=dtrain) predt_0 = xgb.dask.predict(client, model=booster, data=dtrain)
assert pred.ndim == 1 assert predt_0.ndim == 1
assert pred.shape[0] == kRows assert predt_0.shape[0] == kRows
margin = xgb.dask.predict(client, model=booster, data=dtrain, margin = xgb.dask.predict(client, model=booster, data=dtrain, output_margin=True)
output_margin=True) assert margin.ndim == 1
assert margin.ndim == 1 assert margin.shape[0] == kRows
assert margin.shape[0] == kRows
shap = xgb.dask.predict(client, model=booster, data=dtrain, shap = xgb.dask.predict(client, model=booster, data=dtrain, pred_contribs=True)
pred_contribs=True) assert shap.ndim == 2
assert shap.ndim == 2 assert shap.shape[0] == kRows
assert shap.shape[0] == kRows assert shap.shape[1] == kCols + 1
assert shap.shape[1] == kCols + 1
booster_f = client.scatter(booster, broadcast=True)
predt_1 = xgb.dask.predict(client, booster_f, X).compute()
predt_2 = xgb.dask.inplace_predict(client, booster_f, X).compute()
np.testing.assert_allclose(predt_0, predt_1)
np.testing.assert_allclose(predt_0, predt_2)
def test_predict_with_meta(client: "Client") -> None: def test_predict_with_meta(client: "Client") -> None:
@ -1034,7 +1036,7 @@ class TestWithDask:
rows = X.shape[0] rows = X.shape[0]
cols = X.shape[1] cols = X.shape[1]
def assert_shape(shape): def assert_shape(shape: Tuple[int, ...]) -> None:
assert shape[0] == rows assert shape[0] == rows
if "num_class" in params.keys(): if "num_class" in params.keys():
assert shape[1] == params["num_class"] assert shape[1] == params["num_class"]