[dask] Work around segfault in prediction. (#7112)
This commit is contained in:
parent
abec3dbf6d
commit
2f524e9f41
@ -1200,13 +1200,14 @@ async def _predict_async(
|
||||
missing = data.missing
|
||||
meta_names = data.meta_names
|
||||
|
||||
def dispatched_predict(booster: Booster, part: Any) -> numpy.ndarray:
|
||||
def dispatched_predict(booster: Booster, part: Tuple) -> numpy.ndarray:
|
||||
data = part[0]
|
||||
assert isinstance(part, tuple), type(part)
|
||||
base_margin = None
|
||||
for i, blob in enumerate(part[1:]):
|
||||
if meta_names[i] == "base_margin":
|
||||
base_margin = blob
|
||||
# segfault without copy. See https://github.com/dmlc/xgboost/issues/7111.
|
||||
base_margin = blob.copy()
|
||||
with config.config_context(**global_config):
|
||||
m = DMatrix(
|
||||
data,
|
||||
@ -1231,13 +1232,15 @@ async def _predict_async(
|
||||
all_parts = []
|
||||
all_orders = []
|
||||
all_shapes = []
|
||||
all_workers: List[str] = []
|
||||
workers_address = list(data.worker_map.keys())
|
||||
for worker_addr in workers_address:
|
||||
list_of_parts = data.worker_map[worker_addr]
|
||||
all_parts.extend(list_of_parts)
|
||||
all_workers.extend(len(list_of_parts) * [worker_addr])
|
||||
all_orders.extend([partition_order[part.key] for part in list_of_parts])
|
||||
for part in all_parts:
|
||||
s = client.submit(lambda part: part[0].shape[0], part)
|
||||
for w, part in zip(all_workers, all_parts):
|
||||
s = client.submit(lambda part: part[0].shape[0], part, workers=[w])
|
||||
all_shapes.append(s)
|
||||
all_shapes = await client.gather(all_shapes)
|
||||
|
||||
@ -1247,8 +1250,8 @@ async def _predict_async(
|
||||
all_shapes = [shape for part, shape, order in parts_with_order]
|
||||
|
||||
futures = []
|
||||
for part in all_parts:
|
||||
f = client.submit(dispatched_predict, _booster, part)
|
||||
for w, part in zip(all_workers, all_parts):
|
||||
f = client.submit(dispatched_predict, _booster, part, workers=[w])
|
||||
futures.append(f)
|
||||
|
||||
# Constructing a dask array from list of numpy arrays
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user