[dask] Work around segfault in prediction. (#7112)

This commit is contained in:
Jiaming Yuan 2021-07-16 04:27:05 +08:00 committed by GitHub
parent abec3dbf6d
commit 2f524e9f41
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1200,13 +1200,14 @@ async def _predict_async(
missing = data.missing
meta_names = data.meta_names
def dispatched_predict(booster: Booster, part: Any) -> numpy.ndarray:
def dispatched_predict(booster: Booster, part: Tuple) -> numpy.ndarray:
data = part[0]
assert isinstance(part, tuple), type(part)
base_margin = None
for i, blob in enumerate(part[1:]):
if meta_names[i] == "base_margin":
base_margin = blob
# segfault without copy. See https://github.com/dmlc/xgboost/issues/7111.
base_margin = blob.copy()
with config.config_context(**global_config):
m = DMatrix(
data,
@ -1231,13 +1232,15 @@ async def _predict_async(
all_parts = []
all_orders = []
all_shapes = []
all_workers: List[str] = []
workers_address = list(data.worker_map.keys())
for worker_addr in workers_address:
list_of_parts = data.worker_map[worker_addr]
all_parts.extend(list_of_parts)
all_workers.extend(len(list_of_parts) * [worker_addr])
all_orders.extend([partition_order[part.key] for part in list_of_parts])
for part in all_parts:
s = client.submit(lambda part: part[0].shape[0], part)
for w, part in zip(all_workers, all_parts):
s = client.submit(lambda part: part[0].shape[0], part, workers=[w])
all_shapes.append(s)
all_shapes = await client.gather(all_shapes)
@ -1247,8 +1250,8 @@ async def _predict_async(
all_shapes = [shape for part, shape, order in parts_with_order]
futures = []
for part in all_parts:
f = client.submit(dispatched_predict, _booster, part)
for w, part in zip(all_workers, all_parts):
f = client.submit(dispatched_predict, _booster, part, workers=[w])
futures.append(f)
# Constructing a dask array from list of numpy arrays