Fix dask predict (#6412)

This commit is contained in:
Jiaming Yuan 2020-11-20 10:10:52 +08:00 committed by GitHub
parent 44a9d69efb
commit a7b42adb74
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -344,7 +344,7 @@ class DaskDMatrix:
                 'is_quantile': self.is_quantile}
-def _get_worker_parts_ordered(meta_names, list_of_keys, list_of_parts, partition_order):
+def _get_worker_parts_ordered(meta_names, list_of_parts):
     # List of partitions like: [(x3, y3, w3, m3, ..), ..], order is not preserved.
     assert isinstance(list_of_parts, list)
@@ -372,13 +372,8 @@ def _get_worker_parts_ordered(meta_names, list_of_keys, list_of_parts, partition
                 label_upper_bound = blob
             else:
                 raise ValueError('Unknown metainfo:', meta_names[j])
-
-        if partition_order:
-            result.append((data, labels, weights, base_margin, label_lower_bound,
-                           label_upper_bound, partition_order[list_of_keys[i]]))
-        else:
-            result.append((data, labels, weights, base_margin, label_lower_bound,
-                           label_upper_bound))
+        result.append((data, labels, weights, base_margin, label_lower_bound,
+                       label_upper_bound))
     return result
@@ -387,7 +382,7 @@ def _unzip(list_of_parts):
 def _get_worker_parts(list_of_parts: List[tuple], meta_names):
-    partitions = _get_worker_parts_ordered(meta_names, None, list_of_parts, None)
+    partitions = _get_worker_parts_ordered(meta_names, list_of_parts)
     partitions = _unzip(partitions)
     return partitions
@@ -799,19 +794,17 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
     missing = data.missing
     meta_names = data.meta_names
-    def dispatched_predict(worker_id, list_of_keys, list_of_parts):
+    def dispatched_predict(worker_id, list_of_orders, list_of_parts):
         '''Perform prediction on each worker.'''
         LOGGER.info('Predicting on %d', worker_id)
-        c = distributed.get_client()
-        list_of_keys = c.compute(list_of_keys).result()
         worker = distributed.get_worker()
-        list_of_parts = _get_worker_parts_ordered(
-            meta_names, list_of_keys, list_of_parts, partition_order)
+        list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
         predictions = []
         booster.set_param({'nthread': worker.nthreads})
-        for parts in list_of_parts:
-            (data, _, _, base_margin, _, _, order) = parts
+        for i, parts in enumerate(list_of_parts):
+            (data, _, _, base_margin, _, _) = parts
+            order = list_of_orders[i]
             local_part = DMatrix(
                 data,
                 base_margin=base_margin,
@@ -830,21 +823,14 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
         return predictions
-    def dispatched_get_shape(worker_id, list_of_keys, list_of_parts):
+    def dispatched_get_shape(worker_id, list_of_orders, list_of_parts):
         '''Get shape of data in each worker.'''
         LOGGER.info('Get shape on %d', worker_id)
-        c = distributed.get_client()
-        list_of_keys = c.compute(list_of_keys).result()
-        list_of_parts = _get_worker_parts_ordered(
-            meta_names,
-            list_of_keys,
-            list_of_parts,
-            partition_order,
-        )
+        list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
         shapes = []
-        for parts in list_of_parts:
-            (data, _, _, _, _, _, order) = parts
-            shapes.append((data.shape, order))
+        for i, parts in enumerate(list_of_parts):
+            (data, _, _, _, _, _) = parts
+            shapes.append((data.shape, list_of_orders[i]))
         return shapes
     async def map_function(func):
@@ -854,11 +840,13 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
         for wid, worker_addr in enumerate(workers_address):
             worker_addr = workers_address[wid]
             list_of_parts = worker_map[worker_addr]
-            list_of_keys = [part.key for part in list_of_parts]
-            f = await client.submit(func, worker_id=wid,
-                                    list_of_keys=dask.delayed(list_of_keys),
-                                    list_of_parts=list_of_parts,
-                                    pure=False, workers=[worker_addr])
+            list_of_orders = [partition_order[part.key] for part in list_of_parts]
+
+            f = client.submit(func, worker_id=wid,
+                              list_of_orders=list_of_orders,
+                              list_of_parts=list_of_parts,
+                              pure=True, workers=[worker_addr])
+            assert isinstance(f, distributed.client.Future)
             futures.append(f)
         # Get delayed objects
         results = await client.gather(futures)