[dask] Accept Future of model for prediction. (#6650)

This PR changes predict and inplace_predict to accept a Future of model, to avoid sending models to workers repeatedly.

* Document is updated to reflect functionality additions in recent changes.
This commit is contained in:
Jiaming Yuan
2021-02-02 08:45:52 +08:00
committed by GitHub
parent a9ec0ea6da
commit 87ab1ad607
4 changed files with 150 additions and 78 deletions

View File

@@ -940,16 +940,14 @@ def _can_output_df(data: _DaskCollection, output_shape: Tuple) -> bool:
async def _direct_predict_impl(
client: "distributed.Client",
mapped_predict: Callable,
booster: Booster,
booster: "distributed.Future",
data: _DaskCollection,
base_margin: Optional[_DaskCollection],
output_shape: Tuple[int, ...],
meta: Dict[int, str],
) -> _DaskCollection:
columns = list(meta.keys())
booster_f = await client.scatter(data=booster, broadcast=True)
if _can_output_df(data, output_shape):
if base_margin is not None and isinstance(base_margin, da.Array):
base_margin_df: Optional[dd.DataFrame] = base_margin.to_dask_dataframe()
@@ -957,7 +955,7 @@ async def _direct_predict_impl(
base_margin_df = base_margin
predictions = dd.map_partitions(
mapped_predict,
booster_f,
booster,
data,
True,
columns,
@@ -984,7 +982,7 @@ async def _direct_predict_impl(
new_axis = [i + 2 for i in range(len(output_shape) - 2)]
predictions = da.map_blocks(
mapped_predict,
booster_f,
booster,
data,
False,
columns,
@@ -997,7 +995,10 @@ async def _direct_predict_impl(
def _infer_predict_output(
booster: Booster, data: _DaskCollection, inplace: bool, **kwargs: Any
booster: Booster,
data: Union[DaskDMatrix, _DaskCollection],
inplace: bool,
**kwargs: Any
) -> Tuple[Tuple[int, ...], Dict[int, str]]:
"""Create a dummy test sample to infer output shape for prediction."""
if isinstance(data, DaskDMatrix):
@@ -1021,11 +1022,29 @@ def _infer_predict_output(
return test_predt.shape, meta
async def _get_model_future(
client: "distributed.Client", model: Union[Booster, Dict, "distributed.Future"]
) -> "distributed.Future":
if isinstance(model, Booster):
booster = await client.scatter(model, broadcast=True)
elif isinstance(model, dict):
booster = await client.scatter(model["booster"])
elif isinstance(model, distributed.Future):
booster = model
if booster.type is not Booster:
raise TypeError(
f"Underlying type of model future should be `Booster`, got {booster.type}"
)
else:
raise TypeError(_expect([Booster, dict, distributed.Future], type(model)))
return booster
# pylint: disable=too-many-statements
async def _predict_async(
client: "distributed.Client",
global_config: Dict[str, Any],
model: Union[Booster, Dict],
model: Union[Booster, Dict, "distributed.Future"],
data: _DaskCollection,
output_margin: bool,
missing: float,
@@ -1035,12 +1054,7 @@ async def _predict_async(
pred_interactions: bool,
validate_features: bool,
) -> _DaskCollection:
if isinstance(model, Booster):
_booster = model
elif isinstance(model, dict):
_booster = model["booster"]
else:
raise TypeError(_expect([Booster, dict], type(model)))
_booster = await _get_model_future(client, model)
if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)):
raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame], type(data)))
@@ -1070,7 +1084,7 @@ async def _predict_async(
# Predict on dask collection directly.
if isinstance(data, (da.Array, dd.DataFrame)):
_output_shape, meta = _infer_predict_output(
_booster,
await _booster.result(),
data,
inplace=False,
output_margin=output_margin,
@@ -1081,10 +1095,11 @@ async def _predict_async(
validate_features=False,
)
return await _direct_predict_impl(
client, mapped_predict, _booster, data, None, _output_shape, meta
mapped_predict, _booster, data, None, _output_shape, meta
)
output_shape, _ = _infer_predict_output(
booster=_booster,
booster=await _booster.result(),
data=data,
inplace=False,
output_margin=output_margin,
@@ -1108,11 +1123,9 @@ async def _predict_async(
for i, blob in enumerate(part[1:]):
if meta_names[i] == "base_margin":
base_margin = blob
worker = distributed.get_worker()
with config.config_context(**global_config):
m = DMatrix(
data,
nthread=worker.nthreads,
missing=missing,
base_margin=base_margin,
feature_names=feature_names,
@@ -1148,9 +1161,8 @@ async def _predict_async(
all_shapes = [shape for part, shape, order in parts_with_order]
futures = []
booster_f = await client.scatter(data=_booster, broadcast=True)
for part in all_parts:
f = client.submit(dispatched_predict, booster_f, part)
f = client.submit(dispatched_predict, _booster, part)
futures.append(f)
# Constructing a dask array from list of numpy arrays
@@ -1168,7 +1180,7 @@ async def _predict_async(
def predict( # pylint: disable=unused-argument
client: "distributed.Client",
model: Union[TrainReturnT, Booster],
model: Union[TrainReturnT, Booster, "distributed.Future"],
data: Union[DaskDMatrix, _DaskCollection],
output_margin: bool = False,
missing: float = numpy.nan,
@@ -1194,7 +1206,8 @@ def predict( # pylint: disable=unused-argument
Specify the dask client used for training. Use default client
returned from dask if it's set to None.
model:
The trained model.
The trained model. It can be a distributed.Future so user can
pre-scatter it onto all workers.
data:
Input data used for prediction. When input is a dataframe object,
prediction output is a series.
@@ -1221,19 +1234,14 @@ def predict( # pylint: disable=unused-argument
async def _inplace_predict_async(
client: "distributed.Client",
global_config: Dict[str, Any],
model: Union[Booster, Dict],
model: Union[Booster, Dict, "distributed.Future"],
data: _DaskCollection,
iteration_range: Tuple[int, int] = (0, 0),
predict_type: str = 'value',
missing: float = numpy.nan
) -> _DaskCollection:
client = _xgb_get_client(client)
if isinstance(model, Booster):
booster = model
elif isinstance(model, dict):
booster = model['booster']
else:
raise TypeError(_expect([Booster, dict], type(model)))
booster = await _get_model_future(client, model)
if not isinstance(data, (da.Array, dd.DataFrame)):
raise TypeError(_expect([da.Array, dd.DataFrame], type(data)))
@@ -1261,16 +1269,20 @@ async def _inplace_predict_async(
return prediction
shape, meta = _infer_predict_output(
booster, data, True, predict_type=predict_type, iteration_range=iteration_range
await booster.result(),
data,
True,
predict_type=predict_type,
iteration_range=iteration_range
)
return await _direct_predict_impl(
client, mapped_predict, booster, data, None, shape, meta
mapped_predict, booster, data, None, shape, meta
)
def inplace_predict( # pylint: disable=unused-argument
client: "distributed.Client",
model: Union[TrainReturnT, Booster],
model: Union[TrainReturnT, Booster, "distributed.Future"],
data: _DaskCollection,
iteration_range: Tuple[int, int] = (0, 0),
predict_type: str = 'value',
@@ -1286,7 +1298,8 @@ def inplace_predict( # pylint: disable=unused-argument
Specify the dask client used for training. Use default client
returned from dask if it's set to None.
model:
The trained model.
The trained model. It can be a distributed.Future so user can
pre-scatter it onto all workers.
iteration_range:
Specify the range of trees used for prediction.
predict_type:

View File

@@ -535,7 +535,7 @@ class XGBModel(XGBModelBase):
json.dumps({k: v})
meta[k] = v
except TypeError:
warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.')
warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.', UserWarning)
meta['_estimator_type'] = self._get_type()
meta_str = json.dumps(meta)
self.get_booster().set_attr(scikit_learn=meta_str)