[dask] Add a 1 line sample to infer output shape. (#6645)
* [dask] Use a 1 line sample to infer output shape. This is for inferring shape with direct prediction (without DaskDMatrix). There are a few things that requires known output shape before carrying out actual prediction, including dask meta data, output dataframe columns. * Infer output shape based on local prediction. * Remove set param in predict function as it's not thread safe nor necessary as we now let dask to decide the parallelism. * Simplify prediction on `DaskDMatrix`.
This commit is contained in:
@@ -112,14 +112,15 @@ def _start_tracker(n_workers: int) -> Dict[str, Any]:
|
||||
|
||||
def _assert_dask_support() -> None:
|
||||
try:
|
||||
import dask # pylint: disable=W0621,W0611
|
||||
import dask # pylint: disable=W0621,W0611
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
'Dask needs to be installed in order to use this module') from e
|
||||
"Dask needs to be installed in order to use this module"
|
||||
) from e
|
||||
|
||||
if platform.system() == 'Windows':
|
||||
msg = 'Windows is not officially supported for dask/xgboost,'
|
||||
msg += ' contribution are welcomed.'
|
||||
if platform.system() == "Windows":
|
||||
msg = "Windows is not officially supported for dask/xgboost,"
|
||||
msg += " contribution are welcomed."
|
||||
LOGGER.warning(msg)
|
||||
|
||||
|
||||
@@ -252,6 +253,7 @@ class DaskDMatrix:
|
||||
if not isinstance(label, (dd.DataFrame, da.Array, dd.Series, type(None))):
|
||||
raise TypeError(_expect((dd.DataFrame, da.Array, dd.Series), type(label)))
|
||||
|
||||
self._n_cols = data.shape[1]
|
||||
self.worker_map: Dict[str, "distributed.Future"] = defaultdict(list)
|
||||
self.is_quantile: bool = False
|
||||
|
||||
@@ -403,6 +405,9 @@ class DaskDMatrix:
|
||||
'parts': self.worker_map.get(worker_addr, None),
|
||||
'is_quantile': self.is_quantile}
|
||||
|
||||
def num_col(self) -> int:
|
||||
return self._n_cols
|
||||
|
||||
|
||||
_DataParts = List[Tuple[Any, Optional[Any], Optional[Any], Optional[Any], Optional[Any],
|
||||
Optional[Any], Optional[Any]]]
|
||||
@@ -930,27 +935,90 @@ def train(
|
||||
callbacks=callbacks)
|
||||
|
||||
|
||||
def _can_output_df(data: _DaskCollection, output_shape: Tuple) -> bool:
|
||||
return isinstance(data, dd.DataFrame) and len(output_shape) <= 2
|
||||
|
||||
|
||||
async def _direct_predict_impl(
|
||||
client: "distributed.Client",
|
||||
mapped_predict: Callable,
|
||||
booster: Booster,
|
||||
data: _DaskCollection,
|
||||
predict_fn: Callable
|
||||
base_margin: Optional[_DaskCollection],
|
||||
output_shape: Tuple[int, ...],
|
||||
meta: Dict[int, str],
|
||||
) -> _DaskCollection:
|
||||
if isinstance(data, da.Array):
|
||||
predictions = await client.submit(
|
||||
da.map_blocks,
|
||||
predict_fn, data, False, drop_axis=1,
|
||||
dtype=numpy.float32
|
||||
).result()
|
||||
return predictions
|
||||
if isinstance(data, dd.DataFrame):
|
||||
predictions = await client.submit(
|
||||
dd.map_partitions,
|
||||
predict_fn, data, True,
|
||||
meta=dd.utils.make_meta({'prediction': 'f4'})
|
||||
).result()
|
||||
return predictions.iloc[:, 0]
|
||||
raise TypeError('data of type: ' + str(type(data)) +
|
||||
' is not supported by direct prediction')
|
||||
columns = list(meta.keys())
|
||||
booster_f = await client.scatter(data=booster, broadcast=True)
|
||||
if _can_output_df(data, output_shape):
|
||||
if base_margin is not None and isinstance(base_margin, da.Array):
|
||||
base_margin_df: Optional[dd.DataFrame] = base_margin.to_dask_dataframe()
|
||||
else:
|
||||
base_margin_df = base_margin
|
||||
predictions = dd.map_partitions(
|
||||
mapped_predict,
|
||||
booster_f,
|
||||
data,
|
||||
True,
|
||||
columns,
|
||||
base_margin_df,
|
||||
meta=dd.utils.make_meta(meta),
|
||||
)
|
||||
# classification can return a dataframe, drop 1 dim when it's reg/binary
|
||||
if len(output_shape) == 1:
|
||||
predictions = predictions.iloc[:, 0]
|
||||
else:
|
||||
if base_margin is not None and isinstance(
|
||||
base_margin, (dd.Series, dd.DataFrame)
|
||||
):
|
||||
base_margin_array: Optional[da.Array] = base_margin.to_dask_array()
|
||||
else:
|
||||
base_margin_array = base_margin
|
||||
# Input data is 2-dim array, output can be 1(reg, binary)/2(multi-class,
|
||||
# contrib)/3(contrib)/4(interaction) dims.
|
||||
if len(output_shape) == 1:
|
||||
drop_axis: Union[int, List[int]] = [1] # drop from 2 to 1 dim.
|
||||
new_axis: Union[int, List[int]] = []
|
||||
else:
|
||||
drop_axis = []
|
||||
new_axis = [i + 2 for i in range(len(output_shape) - 2)]
|
||||
predictions = da.map_blocks(
|
||||
mapped_predict,
|
||||
booster_f,
|
||||
data,
|
||||
False,
|
||||
columns,
|
||||
base_margin_array,
|
||||
drop_axis=drop_axis,
|
||||
new_axis=new_axis,
|
||||
dtype=numpy.float32,
|
||||
)
|
||||
return predictions
|
||||
|
||||
|
||||
def _infer_predict_output(
|
||||
booster: Booster, data: _DaskCollection, inplace: bool, **kwargs: Any
|
||||
) -> Tuple[Tuple[int, ...], Dict[int, str]]:
|
||||
"""Create a dummy test sample to infer output shape for prediction."""
|
||||
if isinstance(data, DaskDMatrix):
|
||||
features = data.num_col()
|
||||
else:
|
||||
features = data.shape[1]
|
||||
rng = numpy.random.RandomState(1994)
|
||||
test_sample = rng.randn(1, features)
|
||||
if inplace:
|
||||
# clear the state to avoid gpu_id, gpu_predictor
|
||||
booster = Booster(model_file=booster.save_raw())
|
||||
test_predt = booster.inplace_predict(test_sample, **kwargs)
|
||||
else:
|
||||
m = DMatrix(test_sample)
|
||||
test_predt = booster.predict(m, **kwargs)
|
||||
n_columns = test_predt.shape[1] if len(test_predt.shape) > 1 else 1
|
||||
meta: Dict[int, str] = {}
|
||||
if _can_output_df(data, test_predt.shape):
|
||||
for i in range(n_columns):
|
||||
meta[i] = "f4"
|
||||
return test_predt.shape, meta
|
||||
|
||||
|
||||
# pylint: disable=too-many-statements
|
||||
@@ -968,19 +1036,19 @@ async def _predict_async(
|
||||
validate_features: bool,
|
||||
) -> _DaskCollection:
|
||||
if isinstance(model, Booster):
|
||||
booster = model
|
||||
_booster = model
|
||||
elif isinstance(model, dict):
|
||||
booster = model["booster"]
|
||||
_booster = model["booster"]
|
||||
else:
|
||||
raise TypeError(_expect([Booster, dict], type(model)))
|
||||
if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)):
|
||||
raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame], type(data)))
|
||||
|
||||
def mapped_predict(partition: Any, is_df: bool) -> Any:
|
||||
worker = distributed.get_worker()
|
||||
def mapped_predict(
|
||||
booster: Booster, partition: Any, is_df: bool, columns: List[int], _: Any
|
||||
) -> Any:
|
||||
with config.config_context(**global_config):
|
||||
booster.set_param({"nthread": worker.nthreads})
|
||||
m = DMatrix(data=partition, missing=missing, nthread=worker.nthreads)
|
||||
m = DMatrix(data=partition, missing=missing)
|
||||
predt = booster.predict(
|
||||
data=m,
|
||||
output_margin=output_margin,
|
||||
@@ -990,167 +1058,115 @@ async def _predict_async(
|
||||
pred_interactions=pred_interactions,
|
||||
validate_features=validate_features,
|
||||
)
|
||||
if is_df:
|
||||
if is_df and len(predt.shape) <= 2:
|
||||
if lazy_isinstance(partition, "cudf", "core.dataframe.DataFrame"):
|
||||
import cudf
|
||||
predt = cudf.DataFrame(predt, columns=["prediction"])
|
||||
|
||||
predt = cudf.DataFrame(predt, columns=columns)
|
||||
else:
|
||||
predt = DataFrame(predt, columns=["prediction"])
|
||||
predt = DataFrame(predt, columns=columns)
|
||||
return predt
|
||||
|
||||
# Predict on dask collection directly.
|
||||
if isinstance(data, (da.Array, dd.DataFrame)):
|
||||
return await _direct_predict_impl(client, data, mapped_predict)
|
||||
|
||||
_output_shape, meta = _infer_predict_output(
|
||||
_booster,
|
||||
data,
|
||||
inplace=False,
|
||||
output_margin=output_margin,
|
||||
pred_leaf=pred_leaf,
|
||||
pred_contribs=pred_contribs,
|
||||
approx_contribs=approx_contribs,
|
||||
pred_interactions=pred_interactions,
|
||||
validate_features=False,
|
||||
)
|
||||
return await _direct_predict_impl(
|
||||
client, mapped_predict, _booster, data, None, _output_shape, meta
|
||||
)
|
||||
output_shape, _ = _infer_predict_output(
|
||||
booster=_booster,
|
||||
data=data,
|
||||
inplace=False,
|
||||
output_margin=output_margin,
|
||||
pred_leaf=pred_leaf,
|
||||
pred_contribs=pred_contribs,
|
||||
approx_contribs=approx_contribs,
|
||||
pred_interactions=pred_interactions,
|
||||
validate_features=False,
|
||||
)
|
||||
# Prediction on dask DMatrix.
|
||||
worker_map = data.worker_map
|
||||
partition_order = data.partition_order
|
||||
feature_names = data.feature_names
|
||||
feature_types = data.feature_types
|
||||
missing = data.missing
|
||||
meta_names = data.meta_names
|
||||
|
||||
def dispatched_predict(
|
||||
worker_id: int, list_of_orders: List[int], list_of_parts: _DataParts
|
||||
) -> List[Tuple[List[Union["dask.delayed.Delayed", int]], int]]:
|
||||
"""Perform prediction on each worker."""
|
||||
LOGGER.debug("Predicting on %d", worker_id)
|
||||
def dispatched_predict(booster: Booster, part: Any) -> numpy.ndarray:
|
||||
data = part[0]
|
||||
assert isinstance(part, tuple), type(part)
|
||||
base_margin = None
|
||||
for i, blob in enumerate(part[1:]):
|
||||
if meta_names[i] == "base_margin":
|
||||
base_margin = blob
|
||||
worker = distributed.get_worker()
|
||||
with config.config_context(**global_config):
|
||||
worker = distributed.get_worker()
|
||||
list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
|
||||
predictions = []
|
||||
|
||||
booster.set_param({"nthread": worker.nthreads})
|
||||
for i, parts in enumerate(list_of_parts):
|
||||
(data, _, _, base_margin, _, _, _) = parts
|
||||
order = list_of_orders[i]
|
||||
local_part = DMatrix(
|
||||
data,
|
||||
base_margin=base_margin,
|
||||
feature_names=feature_names,
|
||||
feature_types=feature_types,
|
||||
missing=missing,
|
||||
nthread=worker.nthreads,
|
||||
)
|
||||
predt = booster.predict(
|
||||
data=local_part,
|
||||
output_margin=output_margin,
|
||||
pred_leaf=pred_leaf,
|
||||
pred_contribs=pred_contribs,
|
||||
approx_contribs=approx_contribs,
|
||||
pred_interactions=pred_interactions,
|
||||
validate_features=validate_features,
|
||||
)
|
||||
if pred_contribs and predt.size != local_part.num_row():
|
||||
assert len(predt.shape) in (2, 3)
|
||||
if len(predt.shape) == 2:
|
||||
groups = 1
|
||||
columns = predt.shape[1]
|
||||
else:
|
||||
groups = predt.shape[1]
|
||||
columns = predt.shape[2]
|
||||
# pylint: disable=no-member
|
||||
ret = (
|
||||
[dask.delayed(predt), groups, columns],
|
||||
order,
|
||||
)
|
||||
elif pred_interactions and predt.size != local_part.num_row():
|
||||
assert len(predt.shape) in (3, 4)
|
||||
if len(predt.shape) == 3:
|
||||
groups = 1
|
||||
columns = predt.shape[1]
|
||||
else:
|
||||
groups = predt.shape[1]
|
||||
columns = predt.shape[2]
|
||||
# pylint: disable=no-member
|
||||
ret = (
|
||||
[dask.delayed(predt), groups, columns],
|
||||
order,
|
||||
)
|
||||
else:
|
||||
assert len(predt.shape) == 1 or len(predt.shape) == 2
|
||||
columns = 1 if len(predt.shape) == 1 else predt.shape[1]
|
||||
# pylint: disable=no-member
|
||||
ret = (
|
||||
[dask.delayed(predt), columns],
|
||||
order,
|
||||
)
|
||||
predictions.append(ret)
|
||||
|
||||
return predictions
|
||||
|
||||
def dispatched_get_shape(
|
||||
worker_id: int, list_of_orders: List[int], list_of_parts: _DataParts
|
||||
) -> List[Tuple[int, int]]:
|
||||
"""Get shape of data in each worker."""
|
||||
LOGGER.debug("Get shape on %d", worker_id)
|
||||
list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
|
||||
shapes = []
|
||||
for i, parts in enumerate(list_of_parts):
|
||||
(data, _, _, _, _, _, _) = parts
|
||||
shapes.append((data.shape, list_of_orders[i]))
|
||||
return shapes
|
||||
|
||||
async def map_function(
|
||||
func: Callable[[int, List[int], _DataParts], Any]
|
||||
) -> List[Any]:
|
||||
"""Run function for each part of the data."""
|
||||
futures = []
|
||||
workers_address = list(worker_map.keys())
|
||||
for wid, worker_addr in enumerate(workers_address):
|
||||
worker_addr = workers_address[wid]
|
||||
list_of_parts = worker_map[worker_addr]
|
||||
list_of_orders = [partition_order[part.key] for part in list_of_parts]
|
||||
|
||||
f = client.submit(
|
||||
func,
|
||||
worker_id=wid,
|
||||
list_of_orders=list_of_orders,
|
||||
list_of_parts=list_of_parts,
|
||||
pure=True,
|
||||
workers=[worker_addr],
|
||||
m = DMatrix(
|
||||
data,
|
||||
nthread=worker.nthreads,
|
||||
missing=missing,
|
||||
base_margin=base_margin,
|
||||
feature_names=feature_names,
|
||||
feature_types=feature_types,
|
||||
)
|
||||
assert isinstance(f, distributed.client.Future)
|
||||
futures.append(f)
|
||||
# Get delayed objects
|
||||
results = await client.gather(futures)
|
||||
# flatten into 1 dim list
|
||||
results = [t for list_per_worker in results for t in list_per_worker]
|
||||
# sort by order, l[0] is the delayed object, l[1] is its order
|
||||
results = sorted(results, key=lambda l: l[1])
|
||||
results = [predt for predt, order in results] # remove order
|
||||
return results
|
||||
predt = booster.predict(
|
||||
m,
|
||||
output_margin=output_margin,
|
||||
pred_leaf=pred_leaf,
|
||||
pred_contribs=pred_contribs,
|
||||
approx_contribs=approx_contribs,
|
||||
pred_interactions=pred_interactions,
|
||||
validate_features=validate_features,
|
||||
)
|
||||
return predt
|
||||
|
||||
results = await map_function(dispatched_predict)
|
||||
shapes = await map_function(dispatched_get_shape)
|
||||
all_parts = []
|
||||
all_orders = []
|
||||
all_shapes = []
|
||||
workers_address = list(data.worker_map.keys())
|
||||
for worker_addr in workers_address:
|
||||
list_of_parts = data.worker_map[worker_addr]
|
||||
all_parts.extend(list_of_parts)
|
||||
all_orders.extend([partition_order[part.key] for part in list_of_parts])
|
||||
for part in all_parts:
|
||||
s = client.submit(lambda part: part[0].shape[0], part)
|
||||
all_shapes.append(s)
|
||||
all_shapes = await client.gather(all_shapes)
|
||||
|
||||
parts_with_order = list(zip(all_parts, all_shapes, all_orders))
|
||||
parts_with_order = sorted(parts_with_order, key=lambda p: p[2])
|
||||
all_parts = [part for part, shape, order in parts_with_order]
|
||||
all_shapes = [shape for part, shape, order in parts_with_order]
|
||||
|
||||
futures = []
|
||||
booster_f = await client.scatter(data=_booster, broadcast=True)
|
||||
for part in all_parts:
|
||||
f = client.submit(dispatched_predict, booster_f, part)
|
||||
futures.append(f)
|
||||
|
||||
# Constructing a dask array from list of numpy arrays
|
||||
# See https://docs.dask.org/en/latest/array-creation.html
|
||||
arrays = []
|
||||
for i, shape in enumerate(shapes):
|
||||
if pred_contribs:
|
||||
out_shape = (
|
||||
(shape[0], results[i][2])
|
||||
if results[i][1] == 1
|
||||
else (shape[0], results[i][1], results[i][2])
|
||||
)
|
||||
elif pred_interactions:
|
||||
out_shape = (
|
||||
(shape[0], results[i][2], results[i][2])
|
||||
if results[i][1] == 1
|
||||
else (shape[0], results[i][1], results[i][2])
|
||||
)
|
||||
else:
|
||||
out_shape = (shape[0],) if results[i][1] == 1 else (shape[0], results[i][1])
|
||||
for i, rows in enumerate(all_shapes):
|
||||
arrays.append(
|
||||
da.from_delayed(results[i][0], shape=out_shape, dtype=numpy.float32)
|
||||
da.from_delayed(
|
||||
futures[i], shape=(rows,) + output_shape[1:], dtype=numpy.float32
|
||||
)
|
||||
)
|
||||
|
||||
predictions = await da.concatenate(arrays, axis=0)
|
||||
return predictions
|
||||
|
||||
|
||||
def predict(
|
||||
def predict( # pylint: disable=unused-argument
|
||||
client: "distributed.Client",
|
||||
model: Union[TrainReturnT, Booster],
|
||||
data: Union[DaskDMatrix, _DaskCollection],
|
||||
@@ -1190,22 +1206,15 @@ def predict(
|
||||
-------
|
||||
prediction: dask.array.Array/dask.dataframe.Series
|
||||
When input data is ``dask.array.Array`` or ``DaskDMatrix``, the return value is an
|
||||
array, when input data is ``dask.dataframe.DataFrame``, return value is
|
||||
``dask.dataframe.Series``
|
||||
array, when input data is ``dask.dataframe.DataFrame``, return value can be
|
||||
``dask.dataframe.Series``, ``dask.dataframe.DataFrame`` or ``dask.array.Array``,
|
||||
depending on the output shape.
|
||||
|
||||
'''
|
||||
_assert_dask_support()
|
||||
client = _xgb_get_client(client)
|
||||
global_config = config.get_config()
|
||||
return client.sync(
|
||||
_predict_async, client, global_config, model, data,
|
||||
output_margin=output_margin,
|
||||
missing=missing,
|
||||
pred_leaf=pred_leaf,
|
||||
pred_contribs=pred_contribs,
|
||||
approx_contribs=approx_contribs,
|
||||
pred_interactions=pred_interactions,
|
||||
validate_features=validate_features
|
||||
_predict_async, global_config=config.get_config(), **locals()
|
||||
)
|
||||
|
||||
|
||||
@@ -1228,30 +1237,38 @@ async def _inplace_predict_async(
|
||||
if not isinstance(data, (da.Array, dd.DataFrame)):
|
||||
raise TypeError(_expect([da.Array, dd.DataFrame], type(data)))
|
||||
|
||||
def mapped_predict(data: Any, is_df: bool) -> Any:
|
||||
worker = distributed.get_worker()
|
||||
config.set_config(**global_config)
|
||||
booster.set_param({'nthread': worker.nthreads})
|
||||
prediction = booster.inplace_predict(
|
||||
data,
|
||||
iteration_range=iteration_range,
|
||||
predict_type=predict_type,
|
||||
missing=missing)
|
||||
if is_df:
|
||||
def mapped_predict(
|
||||
booster: Booster, data: Any, is_df: bool, columns: List[int], _: Any
|
||||
) -> Any:
|
||||
with config.config_context(**global_config):
|
||||
prediction = booster.inplace_predict(
|
||||
data,
|
||||
iteration_range=iteration_range,
|
||||
predict_type=predict_type,
|
||||
missing=missing
|
||||
)
|
||||
if is_df and len(prediction.shape) <= 2:
|
||||
if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
|
||||
import cudf
|
||||
prediction = cudf.DataFrame({'prediction': prediction},
|
||||
dtype=numpy.float32)
|
||||
prediction = cudf.DataFrame(
|
||||
prediction, columns=columns, dtype=numpy.float32
|
||||
)
|
||||
else:
|
||||
# If it's from pandas, the partition is a numpy array
|
||||
prediction = DataFrame(prediction, columns=['prediction'],
|
||||
dtype=numpy.float32)
|
||||
prediction = DataFrame(
|
||||
prediction, columns=columns, dtype=numpy.float32
|
||||
)
|
||||
return prediction
|
||||
|
||||
return await _direct_predict_impl(client, data, mapped_predict)
|
||||
shape, meta = _infer_predict_output(
|
||||
booster, data, True, predict_type=predict_type, iteration_range=iteration_range
|
||||
)
|
||||
return await _direct_predict_impl(
|
||||
client, mapped_predict, booster, data, None, shape, meta
|
||||
)
|
||||
|
||||
|
||||
def inplace_predict(
|
||||
def inplace_predict( # pylint: disable=unused-argument
|
||||
client: "distributed.Client",
|
||||
model: Union[TrainReturnT, Booster],
|
||||
data: _DaskCollection,
|
||||
@@ -1281,16 +1298,17 @@ def inplace_predict(
|
||||
|
||||
Returns
|
||||
-------
|
||||
prediction
|
||||
prediction :
|
||||
When input data is ``dask.array.Array`` or ``DaskDMatrix``, the return value is an
|
||||
array, when input data is ``dask.dataframe.DataFrame``, return value can be
|
||||
``dask.dataframe.Series``, ``dask.dataframe.DataFrame`` or ``dask.array.Array``,
|
||||
depending on the output shape.
|
||||
'''
|
||||
_assert_dask_support()
|
||||
client = _xgb_get_client(client)
|
||||
global_config = config.get_config()
|
||||
return client.sync(_inplace_predict_async, client, global_config, model=model,
|
||||
data=data,
|
||||
iteration_range=iteration_range,
|
||||
predict_type=predict_type,
|
||||
missing=missing)
|
||||
return client.sync(
|
||||
_inplace_predict_async, global_config=config.get_config(), **locals()
|
||||
)
|
||||
|
||||
|
||||
async def _async_wrap_evaluation_matrices(
|
||||
|
||||
Reference in New Issue
Block a user