Fix dask predict (#6412)

This commit is contained in:
Jiaming Yuan 2020-11-20 10:10:52 +08:00 committed by GitHub
parent 44a9d69efb
commit a7b42adb74
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -344,7 +344,7 @@ class DaskDMatrix:
'is_quantile': self.is_quantile}
def _get_worker_parts_ordered(meta_names, list_of_keys, list_of_parts, partition_order):
def _get_worker_parts_ordered(meta_names, list_of_parts):
# List of partitions like: [(x3, y3, w3, m3, ..), ..], order is not preserved.
assert isinstance(list_of_parts, list)
@ -372,13 +372,8 @@ def _get_worker_parts_ordered(meta_names, list_of_keys, list_of_parts, partition
label_upper_bound = blob
else:
raise ValueError('Unknown metainfo:', meta_names[j])
if partition_order:
result.append((data, labels, weights, base_margin, label_lower_bound,
label_upper_bound, partition_order[list_of_keys[i]]))
else:
result.append((data, labels, weights, base_margin, label_lower_bound,
label_upper_bound))
result.append((data, labels, weights, base_margin, label_lower_bound,
label_upper_bound))
return result
@ -387,7 +382,7 @@ def _unzip(list_of_parts):
def _get_worker_parts(list_of_parts: List[tuple], meta_names):
partitions = _get_worker_parts_ordered(meta_names, None, list_of_parts, None)
partitions = _get_worker_parts_ordered(meta_names, list_of_parts)
partitions = _unzip(partitions)
return partitions
@ -799,19 +794,17 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
missing = data.missing
meta_names = data.meta_names
def dispatched_predict(worker_id, list_of_keys, list_of_parts):
def dispatched_predict(worker_id, list_of_orders, list_of_parts):
'''Perform prediction on each worker.'''
LOGGER.info('Predicting on %d', worker_id)
c = distributed.get_client()
list_of_keys = c.compute(list_of_keys).result()
worker = distributed.get_worker()
list_of_parts = _get_worker_parts_ordered(
meta_names, list_of_keys, list_of_parts, partition_order)
list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
predictions = []
booster.set_param({'nthread': worker.nthreads})
for parts in list_of_parts:
(data, _, _, base_margin, _, _, order) = parts
for i, parts in enumerate(list_of_parts):
(data, _, _, base_margin, _, _) = parts
order = list_of_orders[i]
local_part = DMatrix(
data,
base_margin=base_margin,
@ -830,21 +823,14 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
return predictions
def dispatched_get_shape(worker_id, list_of_keys, list_of_parts):
def dispatched_get_shape(worker_id, list_of_orders, list_of_parts):
'''Get shape of data in each worker.'''
LOGGER.info('Get shape on %d', worker_id)
c = distributed.get_client()
list_of_keys = c.compute(list_of_keys).result()
list_of_parts = _get_worker_parts_ordered(
meta_names,
list_of_keys,
list_of_parts,
partition_order,
)
list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
shapes = []
for parts in list_of_parts:
(data, _, _, _, _, _, order) = parts
shapes.append((data.shape, order))
for i, parts in enumerate(list_of_parts):
(data, _, _, _, _, _) = parts
shapes.append((data.shape, list_of_orders[i]))
return shapes
async def map_function(func):
@ -854,11 +840,13 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
for wid, worker_addr in enumerate(workers_address):
worker_addr = workers_address[wid]
list_of_parts = worker_map[worker_addr]
list_of_keys = [part.key for part in list_of_parts]
f = await client.submit(func, worker_id=wid,
list_of_keys=dask.delayed(list_of_keys),
list_of_parts=list_of_parts,
pure=False, workers=[worker_addr])
list_of_orders = [partition_order[part.key] for part in list_of_parts]
f = client.submit(func, worker_id=wid,
list_of_orders=list_of_orders,
list_of_parts=list_of_parts,
pure=True, workers=[worker_addr])
assert isinstance(f, distributed.client.Future)
futures.append(f)
# Get delayed objects
results = await client.gather(futures)