[dask] Accept Future of model for prediction. (#6650)

This PR changes predict and inplace_predict to accept a Future of model, to avoid sending models to workers repeatably.

* Document is updated to reflect functionality additions in recent changes.
This commit is contained in:
Jiaming Yuan 2021-02-02 08:45:52 +08:00 committed by GitHub
parent a9ec0ea6da
commit 87ab1ad607
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 150 additions and 78 deletions

View File

@ -112,8 +112,11 @@ is a ``DaskDMatrix`` or ``da.Array``. When putting dask collection directly int
``predict`` function or using ``inplace_predict``, the output type depends on input data.
See next section for details.
Alternatively, XGBoost also implements the Scikit-Learn interface with ``DaskXGBClassifier``
and ``DaskXGBRegressor``. See ``xgboost/demo/dask`` for more examples.
Alternatively, XGBoost also implements the Scikit-Learn interface with
``DaskXGBClassifier``, ``DaskXGBRegressor``, ``DaskXGBRanker`` and 2 random forest
variances. This wrapper is similar to the single node Scikit-Learn interface in xgboost,
with dask collection as inputs and has an additional ``client`` attribute. See
``xgboost/demo/dask`` for more examples.
******************
@ -160,6 +163,32 @@ if not using GPU, the number of threads used for prediction on each block matter
now, xgboost uses single thread for each partition. If the number of blocks on each
workers is smaller than number of cores, then the CPU workers might not be fully utilized.
One simple optimization for running consecutive predictions is using
``distributed.Future``:
.. code-block:: python
dataset = [X_0, X_1, X_2]
booster_f = client.scatter(booster, broadcast=True)
futures = []
for X in dataset:
# Here we pass in a future instead of concrete booster
shap_f = xgb.dask.predict(client, booster_f, X, pred_contribs=True)
futures.append(shap_f)
results = client.gather(futures)
This is only available on functional interface, as the Scikit-Learn wrapper doesn't know
how to maintain a valid future for booster. To obtain the booster object from
Scikit-Learn wrapper object:
.. code-block:: python
cls = xgb.dask.DaskXGBClassifier()
cls.fit(X, y)
booster = cls.get_booster()
***************************
@ -231,17 +260,17 @@ will override the configuration in Dask. For example:
with dask.distributed.LocalCluster(n_workers=7, threads_per_worker=4) as cluster:
There are 4 threads allocated for each dask worker. Then by default XGBoost will use 4
threads in each process for both training and prediction. But if ``nthread`` parameter is
set:
threads in each process for training. But if ``nthread`` parameter is set:
.. code-block:: python
output = xgb.dask.train(client,
{'verbosity': 1,
'nthread': 8,
'tree_method': 'hist'},
output = xgb.dask.train(
client,
{"verbosity": 1, "nthread": 8, "tree_method": "hist"},
dtrain,
num_boost_round=4, evals=[(dtrain, 'train')])
num_boost_round=4,
evals=[(dtrain, "train")],
)
XGBoost will use 8 threads in each training process.
@ -274,12 +303,12 @@ Functional interface:
with_X = await xgb.dask.predict(client, output, X)
inplace = await xgb.dask.inplace_predict(client, output, X)
# Use `client.compute` instead of the `compute` method from dask collection
# Use ``client.compute`` instead of the ``compute`` method from dask collection
print(await client.compute(with_m))
While for the Scikit-Learn interface, trivial methods like ``set_params`` and accessing class
attributes like ``evals_result_`` do not require ``await``. Other methods involving
attributes like ``evals_result()`` do not require ``await``. Other methods involving
actual computation will return a coroutine and hence require awaiting:
.. code-block:: python
@ -373,6 +402,46 @@ If early stopping is enabled by also passing ``early_stopping_rounds``, you can
print(booster.best_iteration)
best_model = booster[: booster.best_iteration]
*******************
Other customization
*******************
XGBoost dask interface accepts other advanced features found in single node Python
interface, including callback functions, custom evaluation metric and objective:
def eval_error_metric(predt, dtrain: xgb.DMatrix):
label = dtrain.get_label()
r = np.zeros(predt.shape)
gt = predt > 0.5
r[gt] = 1 - label[gt]
le = predt <= 0.5
r[le] = label[le]
return 'CustomErr', np.sum(r)
# custom callback
early_stop = xgb.callback.EarlyStopping(
rounds=early_stopping_rounds,
metric_name="CustomErr",
data_name="Train",
save_best=True,
)
booster = xgb.dask.train(
client,
params={
"objective": "binary:logistic",
"eval_metric": ["error", "rmse"],
"tree_method": "hist",
},
dtrain=D_train,
evals=[(D_train, "Train"), (D_valid, "Valid")],
feval=eval_error_metric, # custom evaluation metric
num_boost_round=100,
callbacks=[early_stop],
)
*****************************************************************************
Why is the initialization of ``DaskDMatrix`` so slow and throws weird errors
*****************************************************************************
@ -414,15 +483,3 @@ References:
#. https://github.com/dask/dask/issues/6833
#. https://stackoverflow.com/questions/45941528/how-to-efficiently-send-a-large-numpy-array-to-the-cluster-with-dask-array
***********
Limitations
***********
Basic functionality including model training and generating classification and regression predictions
have been implemented. However, there are still some other limitations we haven't
addressed yet:
- Label encoding for the ``DaskXGBClassifier`` classifier may not be supported. So users need
to encode their training labels into discrete values first.
- Ranking is not yet supported.

View File

@ -940,16 +940,14 @@ def _can_output_df(data: _DaskCollection, output_shape: Tuple) -> bool:
async def _direct_predict_impl(
client: "distributed.Client",
mapped_predict: Callable,
booster: Booster,
booster: "distributed.Future",
data: _DaskCollection,
base_margin: Optional[_DaskCollection],
output_shape: Tuple[int, ...],
meta: Dict[int, str],
) -> _DaskCollection:
columns = list(meta.keys())
booster_f = await client.scatter(data=booster, broadcast=True)
if _can_output_df(data, output_shape):
if base_margin is not None and isinstance(base_margin, da.Array):
base_margin_df: Optional[dd.DataFrame] = base_margin.to_dask_dataframe()
@ -957,7 +955,7 @@ async def _direct_predict_impl(
base_margin_df = base_margin
predictions = dd.map_partitions(
mapped_predict,
booster_f,
booster,
data,
True,
columns,
@ -984,7 +982,7 @@ async def _direct_predict_impl(
new_axis = [i + 2 for i in range(len(output_shape) - 2)]
predictions = da.map_blocks(
mapped_predict,
booster_f,
booster,
data,
False,
columns,
@ -997,7 +995,10 @@ async def _direct_predict_impl(
def _infer_predict_output(
booster: Booster, data: _DaskCollection, inplace: bool, **kwargs: Any
booster: Booster,
data: Union[DaskDMatrix, _DaskCollection],
inplace: bool,
**kwargs: Any
) -> Tuple[Tuple[int, ...], Dict[int, str]]:
"""Create a dummy test sample to infer output shape for prediction."""
if isinstance(data, DaskDMatrix):
@ -1021,11 +1022,29 @@ def _infer_predict_output(
return test_predt.shape, meta
async def _get_model_future(
client: "distributed.Client", model: Union[Booster, Dict, "distributed.Future"]
) -> "distributed.Future":
if isinstance(model, Booster):
booster = await client.scatter(model, broadcast=True)
elif isinstance(model, dict):
booster = await client.scatter(model["booster"])
elif isinstance(model, distributed.Future):
booster = model
if booster.type is not Booster:
raise TypeError(
f"Underlying type of model future should be `Booster`, got {booster.type}"
)
else:
raise TypeError(_expect([Booster, dict, distributed.Future], type(model)))
return booster
# pylint: disable=too-many-statements
async def _predict_async(
client: "distributed.Client",
global_config: Dict[str, Any],
model: Union[Booster, Dict],
model: Union[Booster, Dict, "distributed.Future"],
data: _DaskCollection,
output_margin: bool,
missing: float,
@ -1035,12 +1054,7 @@ async def _predict_async(
pred_interactions: bool,
validate_features: bool,
) -> _DaskCollection:
if isinstance(model, Booster):
_booster = model
elif isinstance(model, dict):
_booster = model["booster"]
else:
raise TypeError(_expect([Booster, dict], type(model)))
_booster = await _get_model_future(client, model)
if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)):
raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame], type(data)))
@ -1070,7 +1084,7 @@ async def _predict_async(
# Predict on dask collection directly.
if isinstance(data, (da.Array, dd.DataFrame)):
_output_shape, meta = _infer_predict_output(
_booster,
await _booster.result(),
data,
inplace=False,
output_margin=output_margin,
@ -1081,10 +1095,11 @@ async def _predict_async(
validate_features=False,
)
return await _direct_predict_impl(
client, mapped_predict, _booster, data, None, _output_shape, meta
mapped_predict, _booster, data, None, _output_shape, meta
)
output_shape, _ = _infer_predict_output(
booster=_booster,
booster=await _booster.result(),
data=data,
inplace=False,
output_margin=output_margin,
@ -1108,11 +1123,9 @@ async def _predict_async(
for i, blob in enumerate(part[1:]):
if meta_names[i] == "base_margin":
base_margin = blob
worker = distributed.get_worker()
with config.config_context(**global_config):
m = DMatrix(
data,
nthread=worker.nthreads,
missing=missing,
base_margin=base_margin,
feature_names=feature_names,
@ -1148,9 +1161,8 @@ async def _predict_async(
all_shapes = [shape for part, shape, order in parts_with_order]
futures = []
booster_f = await client.scatter(data=_booster, broadcast=True)
for part in all_parts:
f = client.submit(dispatched_predict, booster_f, part)
f = client.submit(dispatched_predict, _booster, part)
futures.append(f)
# Constructing a dask array from list of numpy arrays
@ -1168,7 +1180,7 @@ async def _predict_async(
def predict( # pylint: disable=unused-argument
client: "distributed.Client",
model: Union[TrainReturnT, Booster],
model: Union[TrainReturnT, Booster, "distributed.Future"],
data: Union[DaskDMatrix, _DaskCollection],
output_margin: bool = False,
missing: float = numpy.nan,
@ -1194,7 +1206,8 @@ def predict( # pylint: disable=unused-argument
Specify the dask client used for training. Use default client
returned from dask if it's set to None.
model:
The trained model.
The trained model. It can be a distributed.Future so user can
pre-scatter it onto all workers.
data:
Input data used for prediction. When input is a dataframe object,
prediction output is a series.
@ -1221,19 +1234,14 @@ def predict( # pylint: disable=unused-argument
async def _inplace_predict_async(
client: "distributed.Client",
global_config: Dict[str, Any],
model: Union[Booster, Dict],
model: Union[Booster, Dict, "distributed.Future"],
data: _DaskCollection,
iteration_range: Tuple[int, int] = (0, 0),
predict_type: str = 'value',
missing: float = numpy.nan
) -> _DaskCollection:
client = _xgb_get_client(client)
if isinstance(model, Booster):
booster = model
elif isinstance(model, dict):
booster = model['booster']
else:
raise TypeError(_expect([Booster, dict], type(model)))
booster = await _get_model_future(client, model)
if not isinstance(data, (da.Array, dd.DataFrame)):
raise TypeError(_expect([da.Array, dd.DataFrame], type(data)))
@ -1261,16 +1269,20 @@ async def _inplace_predict_async(
return prediction
shape, meta = _infer_predict_output(
booster, data, True, predict_type=predict_type, iteration_range=iteration_range
await booster.result(),
data,
True,
predict_type=predict_type,
iteration_range=iteration_range
)
return await _direct_predict_impl(
client, mapped_predict, booster, data, None, shape, meta
mapped_predict, booster, data, None, shape, meta
)
def inplace_predict( # pylint: disable=unused-argument
client: "distributed.Client",
model: Union[TrainReturnT, Booster],
model: Union[TrainReturnT, Booster, "distributed.Future"],
data: _DaskCollection,
iteration_range: Tuple[int, int] = (0, 0),
predict_type: str = 'value',
@ -1286,7 +1298,8 @@ def inplace_predict( # pylint: disable=unused-argument
Specify the dask client used for training. Use default client
returned from dask if it's set to None.
model:
The trained model.
The trained model. It can be a distributed.Future so user can
pre-scatter it onto all workers.
iteration_range:
Specify the range of trees used for prediction.
predict_type:

View File

@ -535,7 +535,7 @@ class XGBModel(XGBModelBase):
json.dumps({k: v})
meta[k] = v
except TypeError:
warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.')
warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.', UserWarning)
meta['_estimator_type'] = self._get_type()
meta_str = json.dumps(meta)
self.get_booster().set_attr(scikit_learn=meta_str)

View File

@ -608,29 +608,31 @@ def test_with_asyncio() -> None:
asyncio.run(run_dask_classifier_asyncio(address))
def test_predict() -> None:
with LocalCluster(n_workers=kWorkers) as cluster:
with Client(cluster) as client:
def test_predict(client: "Client") -> None:
X, y, _ = generate_array()
dtrain = DaskDMatrix(client, X, y)
booster = xgb.dask.train(
client, {}, dtrain, num_boost_round=2)['booster']
booster = xgb.dask.train(client, {}, dtrain, num_boost_round=2)["booster"]
pred = xgb.dask.predict(client, model=booster, data=dtrain)
assert pred.ndim == 1
assert pred.shape[0] == kRows
predt_0 = xgb.dask.predict(client, model=booster, data=dtrain)
assert predt_0.ndim == 1
assert predt_0.shape[0] == kRows
margin = xgb.dask.predict(client, model=booster, data=dtrain,
output_margin=True)
margin = xgb.dask.predict(client, model=booster, data=dtrain, output_margin=True)
assert margin.ndim == 1
assert margin.shape[0] == kRows
shap = xgb.dask.predict(client, model=booster, data=dtrain,
pred_contribs=True)
shap = xgb.dask.predict(client, model=booster, data=dtrain, pred_contribs=True)
assert shap.ndim == 2
assert shap.shape[0] == kRows
assert shap.shape[1] == kCols + 1
booster_f = client.scatter(booster, broadcast=True)
predt_1 = xgb.dask.predict(client, booster_f, X).compute()
predt_2 = xgb.dask.inplace_predict(client, booster_f, X).compute()
np.testing.assert_allclose(predt_0, predt_1)
np.testing.assert_allclose(predt_0, predt_2)
def test_predict_with_meta(client: "Client") -> None:
X, y, w = generate_array(with_weights=True)
@ -1034,7 +1036,7 @@ class TestWithDask:
rows = X.shape[0]
cols = X.shape[1]
def assert_shape(shape):
def assert_shape(shape: Tuple[int, ...]) -> None:
assert shape[0] == rows
if "num_class" in params.keys():
assert shape[1] == params["num_class"]