From 2f524e9f41e0c0c7b02ad4574cbb192e6be7c066 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 16 Jul 2021 04:27:05 +0800 Subject: [PATCH] [dask] Work around segfault in prediction. (#7112) --- python-package/xgboost/dask.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index a6b47906c..f203f88b2 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -1200,13 +1200,14 @@ async def _predict_async( missing = data.missing meta_names = data.meta_names - def dispatched_predict(booster: Booster, part: Any) -> numpy.ndarray: + def dispatched_predict(booster: Booster, part: Tuple) -> numpy.ndarray: data = part[0] assert isinstance(part, tuple), type(part) base_margin = None for i, blob in enumerate(part[1:]): if meta_names[i] == "base_margin": - base_margin = blob + # segfault without copy. See https://github.com/dmlc/xgboost/issues/7111. + base_margin = blob.copy() with config.config_context(**global_config): m = DMatrix( data, @@ -1231,13 +1232,15 @@ async def _predict_async( all_parts = [] all_orders = [] all_shapes = [] + all_workers: List[str] = [] workers_address = list(data.worker_map.keys()) for worker_addr in workers_address: list_of_parts = data.worker_map[worker_addr] all_parts.extend(list_of_parts) + all_workers.extend(len(list_of_parts) * [worker_addr]) all_orders.extend([partition_order[part.key] for part in list_of_parts]) - for part in all_parts: - s = client.submit(lambda part: part[0].shape[0], part) + for w, part in zip(all_workers, all_parts): + s = client.submit(lambda part: part[0].shape[0], part, workers=[w]) all_shapes.append(s) all_shapes = await client.gather(all_shapes) @@ -1247,8 +1250,8 @@ async def _predict_async( all_shapes = [shape for part, shape, order in parts_with_order] futures = [] - for part in all_parts: - f = client.submit(dispatched_predict, _booster, part) + for w, part in zip(all_workers, all_parts): + f = client.submit(dispatched_predict, _booster, part, workers=[w]) futures.append(f) # Constructing a dask array from list of numpy arrays