[dask] Accept Future of model for prediction. (#6650)
This PR changes predict and inplace_predict to accept a Future of the model, to avoid sending the model to workers repeatedly. * Documentation is updated to reflect functionality added in recent changes.
This commit is contained in:
@@ -940,16 +940,14 @@ def _can_output_df(data: _DaskCollection, output_shape: Tuple) -> bool:
|
||||
|
||||
|
||||
async def _direct_predict_impl(
|
||||
client: "distributed.Client",
|
||||
mapped_predict: Callable,
|
||||
booster: Booster,
|
||||
booster: "distributed.Future",
|
||||
data: _DaskCollection,
|
||||
base_margin: Optional[_DaskCollection],
|
||||
output_shape: Tuple[int, ...],
|
||||
meta: Dict[int, str],
|
||||
) -> _DaskCollection:
|
||||
columns = list(meta.keys())
|
||||
booster_f = await client.scatter(data=booster, broadcast=True)
|
||||
if _can_output_df(data, output_shape):
|
||||
if base_margin is not None and isinstance(base_margin, da.Array):
|
||||
base_margin_df: Optional[dd.DataFrame] = base_margin.to_dask_dataframe()
|
||||
@@ -957,7 +955,7 @@ async def _direct_predict_impl(
|
||||
base_margin_df = base_margin
|
||||
predictions = dd.map_partitions(
|
||||
mapped_predict,
|
||||
booster_f,
|
||||
booster,
|
||||
data,
|
||||
True,
|
||||
columns,
|
||||
@@ -984,7 +982,7 @@ async def _direct_predict_impl(
|
||||
new_axis = [i + 2 for i in range(len(output_shape) - 2)]
|
||||
predictions = da.map_blocks(
|
||||
mapped_predict,
|
||||
booster_f,
|
||||
booster,
|
||||
data,
|
||||
False,
|
||||
columns,
|
||||
@@ -997,7 +995,10 @@ async def _direct_predict_impl(
|
||||
|
||||
|
||||
def _infer_predict_output(
|
||||
booster: Booster, data: _DaskCollection, inplace: bool, **kwargs: Any
|
||||
booster: Booster,
|
||||
data: Union[DaskDMatrix, _DaskCollection],
|
||||
inplace: bool,
|
||||
**kwargs: Any
|
||||
) -> Tuple[Tuple[int, ...], Dict[int, str]]:
|
||||
"""Create a dummy test sample to infer output shape for prediction."""
|
||||
if isinstance(data, DaskDMatrix):
|
||||
@@ -1021,11 +1022,29 @@ def _infer_predict_output(
|
||||
return test_predt.shape, meta
|
||||
|
||||
|
||||
async def _get_model_future(
    client: "distributed.Client", model: Union[Booster, Dict, "distributed.Future"]
) -> "distributed.Future":
    """Normalize the accepted model inputs into a future of a ``Booster``.

    Parameters
    ----------
    client:
        The dask client used to scatter the model.
    model:
        One of: a ``Booster``; a dict holding the booster under the
        ``"booster"`` key; or a ``distributed.Future`` whose underlying type is
        ``Booster`` (returned unchanged, so the caller can pre-scatter it).

    Returns
    -------
    A ``distributed.Future`` referring to the booster.

    Raises
    ------
    TypeError
        If ``model`` is none of the accepted types, or if it is a future whose
        underlying type is not ``Booster``.
    """
    if isinstance(model, Booster):
        booster = await client.scatter(model, broadcast=True)
    elif isinstance(model, dict):
        # Broadcast here too, matching the ``Booster`` branch, so both input
        # forms place the model on the workers the same way.
        booster = await client.scatter(model["booster"], broadcast=True)
    elif isinstance(model, distributed.Future):
        booster = model
        # A future of anything other than a Booster cannot be used for
        # prediction; reject it here rather than failing later on a worker.
        if booster.type is not Booster:
            raise TypeError(
                f"Underlying type of model future should be `Booster`, got {booster.type}"
            )
    else:
        raise TypeError(_expect([Booster, dict, distributed.Future], type(model)))
    return booster
|
||||
|
||||
|
||||
# pylint: disable=too-many-statements
|
||||
async def _predict_async(
|
||||
client: "distributed.Client",
|
||||
global_config: Dict[str, Any],
|
||||
model: Union[Booster, Dict],
|
||||
model: Union[Booster, Dict, "distributed.Future"],
|
||||
data: _DaskCollection,
|
||||
output_margin: bool,
|
||||
missing: float,
|
||||
@@ -1035,12 +1054,7 @@ async def _predict_async(
|
||||
pred_interactions: bool,
|
||||
validate_features: bool,
|
||||
) -> _DaskCollection:
|
||||
if isinstance(model, Booster):
|
||||
_booster = model
|
||||
elif isinstance(model, dict):
|
||||
_booster = model["booster"]
|
||||
else:
|
||||
raise TypeError(_expect([Booster, dict], type(model)))
|
||||
_booster = await _get_model_future(client, model)
|
||||
if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)):
|
||||
raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame], type(data)))
|
||||
|
||||
@@ -1070,7 +1084,7 @@ async def _predict_async(
|
||||
# Predict on dask collection directly.
|
||||
if isinstance(data, (da.Array, dd.DataFrame)):
|
||||
_output_shape, meta = _infer_predict_output(
|
||||
_booster,
|
||||
await _booster.result(),
|
||||
data,
|
||||
inplace=False,
|
||||
output_margin=output_margin,
|
||||
@@ -1081,10 +1095,11 @@ async def _predict_async(
|
||||
validate_features=False,
|
||||
)
|
||||
return await _direct_predict_impl(
|
||||
client, mapped_predict, _booster, data, None, _output_shape, meta
|
||||
mapped_predict, _booster, data, None, _output_shape, meta
|
||||
)
|
||||
|
||||
output_shape, _ = _infer_predict_output(
|
||||
booster=_booster,
|
||||
booster=await _booster.result(),
|
||||
data=data,
|
||||
inplace=False,
|
||||
output_margin=output_margin,
|
||||
@@ -1108,11 +1123,9 @@ async def _predict_async(
|
||||
for i, blob in enumerate(part[1:]):
|
||||
if meta_names[i] == "base_margin":
|
||||
base_margin = blob
|
||||
worker = distributed.get_worker()
|
||||
with config.config_context(**global_config):
|
||||
m = DMatrix(
|
||||
data,
|
||||
nthread=worker.nthreads,
|
||||
missing=missing,
|
||||
base_margin=base_margin,
|
||||
feature_names=feature_names,
|
||||
@@ -1148,9 +1161,8 @@ async def _predict_async(
|
||||
all_shapes = [shape for part, shape, order in parts_with_order]
|
||||
|
||||
futures = []
|
||||
booster_f = await client.scatter(data=_booster, broadcast=True)
|
||||
for part in all_parts:
|
||||
f = client.submit(dispatched_predict, booster_f, part)
|
||||
f = client.submit(dispatched_predict, _booster, part)
|
||||
futures.append(f)
|
||||
|
||||
# Constructing a dask array from list of numpy arrays
|
||||
@@ -1168,7 +1180,7 @@ async def _predict_async(
|
||||
|
||||
def predict( # pylint: disable=unused-argument
|
||||
client: "distributed.Client",
|
||||
model: Union[TrainReturnT, Booster],
|
||||
model: Union[TrainReturnT, Booster, "distributed.Future"],
|
||||
data: Union[DaskDMatrix, _DaskCollection],
|
||||
output_margin: bool = False,
|
||||
missing: float = numpy.nan,
|
||||
@@ -1194,7 +1206,8 @@ def predict( # pylint: disable=unused-argument
|
||||
Specify the dask client used for training. Use default client
|
||||
returned from dask if it's set to None.
|
||||
model:
|
||||
The trained model.
|
||||
The trained model. It can be a distributed.Future so user can
|
||||
pre-scatter it onto all workers.
|
||||
data:
|
||||
Input data used for prediction. When input is a dataframe object,
|
||||
prediction output is a series.
|
||||
@@ -1221,19 +1234,14 @@ def predict( # pylint: disable=unused-argument
|
||||
async def _inplace_predict_async(
|
||||
client: "distributed.Client",
|
||||
global_config: Dict[str, Any],
|
||||
model: Union[Booster, Dict],
|
||||
model: Union[Booster, Dict, "distributed.Future"],
|
||||
data: _DaskCollection,
|
||||
iteration_range: Tuple[int, int] = (0, 0),
|
||||
predict_type: str = 'value',
|
||||
missing: float = numpy.nan
|
||||
) -> _DaskCollection:
|
||||
client = _xgb_get_client(client)
|
||||
if isinstance(model, Booster):
|
||||
booster = model
|
||||
elif isinstance(model, dict):
|
||||
booster = model['booster']
|
||||
else:
|
||||
raise TypeError(_expect([Booster, dict], type(model)))
|
||||
booster = await _get_model_future(client, model)
|
||||
if not isinstance(data, (da.Array, dd.DataFrame)):
|
||||
raise TypeError(_expect([da.Array, dd.DataFrame], type(data)))
|
||||
|
||||
@@ -1261,16 +1269,20 @@ async def _inplace_predict_async(
|
||||
return prediction
|
||||
|
||||
shape, meta = _infer_predict_output(
|
||||
booster, data, True, predict_type=predict_type, iteration_range=iteration_range
|
||||
await booster.result(),
|
||||
data,
|
||||
True,
|
||||
predict_type=predict_type,
|
||||
iteration_range=iteration_range
|
||||
)
|
||||
return await _direct_predict_impl(
|
||||
client, mapped_predict, booster, data, None, shape, meta
|
||||
mapped_predict, booster, data, None, shape, meta
|
||||
)
|
||||
|
||||
|
||||
def inplace_predict( # pylint: disable=unused-argument
|
||||
client: "distributed.Client",
|
||||
model: Union[TrainReturnT, Booster],
|
||||
model: Union[TrainReturnT, Booster, "distributed.Future"],
|
||||
data: _DaskCollection,
|
||||
iteration_range: Tuple[int, int] = (0, 0),
|
||||
predict_type: str = 'value',
|
||||
@@ -1286,7 +1298,8 @@ def inplace_predict( # pylint: disable=unused-argument
|
||||
Specify the dask client used for training. Use default client
|
||||
returned from dask if it's set to None.
|
||||
model:
|
||||
The trained model.
|
||||
The trained model. It can be a distributed.Future so user can
|
||||
pre-scatter it onto all workers.
|
||||
iteration_range:
|
||||
Specify the range of trees used for prediction.
|
||||
predict_type:
|
||||
|
||||
@@ -535,7 +535,7 @@ class XGBModel(XGBModelBase):
|
||||
json.dumps({k: v})
|
||||
meta[k] = v
|
||||
except TypeError:
|
||||
warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.')
|
||||
warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.', UserWarning)
|
||||
meta['_estimator_type'] = self._get_type()
|
||||
meta_str = json.dumps(meta)
|
||||
self.get_booster().set_attr(scikit_learn=meta_str)
|
||||
|
||||
Reference in New Issue
Block a user