Fix dask predict (#6412)

parent 44a9d69efb
commit a7b42adb74
@@ -344,7 +344,7 @@ class DaskDMatrix:
                 'is_quantile': self.is_quantile}


-def _get_worker_parts_ordered(meta_names, list_of_keys, list_of_parts, partition_order):
+def _get_worker_parts_ordered(meta_names, list_of_parts):
     # List of partitions like: [(x3, y3, w3, m3, ..), ..], order is not preserved.
     assert isinstance(list_of_parts, list)
@@ -372,13 +372,8 @@ def _get_worker_parts_ordered(meta_names, list_of_keys, list_of_parts, partition
                 label_upper_bound = blob
             else:
                 raise ValueError('Unknown metainfo:', meta_names[j])
-
-        if partition_order:
-            result.append((data, labels, weights, base_margin, label_lower_bound,
-                           label_upper_bound, partition_order[list_of_keys[i]]))
-        else:
-            result.append((data, labels, weights, base_margin, label_lower_bound,
-                           label_upper_bound))
+        result.append((data, labels, weights, base_margin, label_lower_bound,
+                       label_upper_bound))
     return result
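
Taken together, the two hunks above reduce _get_worker_parts_ordered to a plain unpacking step: partitions are consumed in exactly the order the caller hands them over, and the partition index is no longer looked up via list_of_keys/partition_order inside the helper. The following standalone sketch is not the library code; the helper name and the concrete meta-info strings ('labels', 'weights', 'base_margin', 'label_lower_bound', 'label_upper_bound') are assumptions made for illustration.

# Standalone sketch of the simplified helper: unpack each partition tuple into
# (data, labels, weights, base_margin, label_lower_bound, label_upper_bound),
# preserving the order in which partitions were passed in.
def get_worker_parts_ordered_sketch(meta_names, list_of_parts):
    assert isinstance(list_of_parts, list)
    result = []
    for part in list_of_parts:
        data = part[0]                      # the feature matrix comes first
        meta = dict.fromkeys(('labels', 'weights', 'base_margin',
                              'label_lower_bound', 'label_upper_bound'))
        for j, blob in enumerate(part[1:]):
            if meta_names[j] not in meta:
                raise ValueError('Unknown metainfo:', meta_names[j])
            meta[meta_names[j]] = blob
        result.append((data, meta['labels'], meta['weights'], meta['base_margin'],
                       meta['label_lower_bound'], meta['label_upper_bound']))
    return result


# Example: two partitions of (data, labels); the output order follows the input order.
parts = [(['x0'], [0, 1]), (['x1'], [1, 0])]
print(get_worker_parts_ordered_sketch(['labels'], parts))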
@@ -387,7 +382,7 @@ def _unzip(list_of_parts):


 def _get_worker_parts(list_of_parts: List[tuple], meta_names):
-    partitions = _get_worker_parts_ordered(meta_names, None, list_of_parts, None)
+    partitions = _get_worker_parts_ordered(meta_names, list_of_parts)
     partitions = _unzip(partitions)
     return partitions
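
The call site in _get_worker_parts now simply forwards list_of_parts and then transposes the result with _unzip. The body of _unzip is not part of this diff; a minimal stand-in (hypothetical name) that transposes per-partition tuples into per-field lists could look like this.

# Hypothetical stand-in for an _unzip-style transpose: turn
# [(x0, y0), (x1, y1), ...] into [(x0, x1, ...), (y0, y1, ...)].
def unzip_sketch(list_of_parts):
    return list(zip(*list_of_parts))

# Example: two partitions of (data, labels) become (all data, all labels).
print(unzip_sketch([('x0', 'y0'), ('x1', 'y1')]))   # [('x0', 'x1'), ('y0', 'y1')]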
@@ -799,19 +794,17 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
     missing = data.missing
     meta_names = data.meta_names

-    def dispatched_predict(worker_id, list_of_keys, list_of_parts):
+    def dispatched_predict(worker_id, list_of_orders, list_of_parts):
         '''Perform prediction on each worker.'''
         LOGGER.info('Predicting on %d', worker_id)
-        c = distributed.get_client()
-        list_of_keys = c.compute(list_of_keys).result()
         worker = distributed.get_worker()
-        list_of_parts = _get_worker_parts_ordered(
-            meta_names, list_of_keys, list_of_parts, partition_order)
+        list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
         predictions = []

         booster.set_param({'nthread': worker.nthreads})
-        for parts in list_of_parts:
-            (data, _, _, base_margin, _, _, order) = parts
+        for i, parts in enumerate(list_of_parts):
+            (data, _, _, base_margin, _, _) = parts
+            order = list_of_orders[i]
             local_part = DMatrix(
                 data,
                 base_margin=base_margin,
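
The behavioural change in dispatched_predict: the old version called distributed.get_client() and c.compute(list_of_keys).result() from inside the task to recover each partition's index, while the new version receives that index as plain data (list_of_orders) aligned element-wise with list_of_parts, so the task never has to call back into the scheduler. A toy, runnable sketch of the new control flow follows; dispatched_predict_sketch and predict_fn are made-up stand-ins for building a DMatrix and calling booster.predict.

# Toy sketch of the new worker-side flow: the partition order arrives as plain
# data, so no client/scheduler round trip is needed inside the task.
def dispatched_predict_sketch(predict_fn, list_of_orders, list_of_parts):
    predictions = []
    for i, parts in enumerate(list_of_parts):
        data = parts[0]                 # feature block for this partition
        order = list_of_orders[i]       # global partition index
        predictions.append((predict_fn(data), order))
    return predictions

# The caller can restore the original partition order afterwards by sorting on
# the attached index before concatenating the per-partition predictions.
parts = [(['b'],), (['a'],)]
out = dispatched_predict_sketch(lambda d: ['pred:' + v for v in d], [1, 0], parts)
out.sort(key=lambda pair: pair[1])
print(out)   # [(['pred:a'], 0), (['pred:b'], 1)]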
@@ -830,21 +823,14 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):

         return predictions

-    def dispatched_get_shape(worker_id, list_of_keys, list_of_parts):
+    def dispatched_get_shape(worker_id, list_of_orders, list_of_parts):
         '''Get shape of data in each worker.'''
         LOGGER.info('Get shape on %d', worker_id)
-        c = distributed.get_client()
-        list_of_keys = c.compute(list_of_keys).result()
-        list_of_parts = _get_worker_parts_ordered(
-            meta_names,
-            list_of_keys,
-            list_of_parts,
-            partition_order,
-        )
+        list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
         shapes = []
-        for parts in list_of_parts:
-            (data, _, _, _, _, _, order) = parts
-            shapes.append((data.shape, order))
+        for i, parts in enumerate(list_of_parts):
+            (data, _, _, _, _, _) = parts
+            shapes.append((data.shape, list_of_orders[i]))
         return shapes

     async def map_function(func):
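
dispatched_get_shape follows the same pattern: each partition's shape is paired with its global index taken from list_of_orders instead of a key-derived order. How those pairs are consumed downstream is not shown in this diff; the small sketch below (illustrative only, hypothetical helper name) shows one natural use, recovering per-partition row counts in the original partition order.

# Illustrative only: given (shape, order) pairs collected from all workers,
# recover the row count of every partition in its original order.
def rows_in_partition_order(shape_order_pairs):
    ordered = sorted(shape_order_pairs, key=lambda pair: pair[1])
    return [shape[0] for shape, _ in ordered]

pairs = [((50, 8), 1), ((100, 8), 0), ((25, 8), 2)]
print(rows_in_partition_order(pairs))   # [100, 50, 25]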
@@ -854,11 +840,13 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
         for wid, worker_addr in enumerate(workers_address):
             worker_addr = workers_address[wid]
             list_of_parts = worker_map[worker_addr]
-            list_of_keys = [part.key for part in list_of_parts]
-            f = await client.submit(func, worker_id=wid,
-                                    list_of_keys=dask.delayed(list_of_keys),
-                                    list_of_parts=list_of_parts,
-                                    pure=False, workers=[worker_addr])
+            list_of_orders = [partition_order[part.key] for part in list_of_parts]
+
+            f = client.submit(func, worker_id=wid,
+                              list_of_orders=list_of_orders,
+                              list_of_parts=list_of_parts,
+                              pure=True, workers=[worker_addr])
+            assert isinstance(f, distributed.client.Future)
             futures.append(f)
         # Get delayed objects
         results = await client.gather(futures)
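
On the scheduler side, the per-worker submission no longer awaits client.submit or ships dask.delayed(list_of_keys); it looks up each part's partition index in partition_order up front and passes the resulting plain list, marking the task pure=True. Below is a self-contained sketch of just that bookkeeping; FakePart, the part keys, and the worker addresses are made up, and the real code submits distributed Futures rather than printing.

# Self-contained sketch of the new scheduler-side bookkeeping: translate each
# worker's list of parts into a plain list of partition indices up front.
partition_order = {'part-a': 0, 'part-b': 1, 'part-c': 2}   # part key -> index

class FakePart:                     # stand-in for a Future-like object with .key
    def __init__(self, key):
        self.key = key

worker_map = {
    'tcp://worker-0:43210': [FakePart('part-b'), FakePart('part-a')],
    'tcp://worker-1:43211': [FakePart('part-c')],
}

for wid, (worker_addr, list_of_parts) in enumerate(worker_map.items()):
    list_of_orders = [partition_order[part.key] for part in list_of_parts]
    # In the real code this is where the task is submitted, roughly:
    #   f = client.submit(func, worker_id=wid, list_of_orders=list_of_orders,
    #                     list_of_parts=list_of_parts, pure=True,
    #                     workers=[worker_addr])
    print(wid, worker_addr, list_of_orders)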