[dask] Order the prediction result. (#5416)
This commit is contained in:
parent
668e432e2d
commit
21b671aa06
@ -142,9 +142,6 @@ class DaskDMatrix:
|
||||
|
||||
'''
|
||||
|
||||
_feature_names = None # for previous version's pickle
|
||||
_feature_types = None
|
||||
|
||||
def __init__(self,
|
||||
client,
|
||||
data,
|
||||
@ -156,9 +153,9 @@ class DaskDMatrix:
|
||||
_assert_dask_support()
|
||||
_assert_client(client)
|
||||
|
||||
self._feature_names = feature_names
|
||||
self._feature_types = feature_types
|
||||
self._missing = missing
|
||||
self.feature_names = feature_names
|
||||
self.feature_types = feature_types
|
||||
self.missing = missing
|
||||
|
||||
if len(data.shape) != 2:
|
||||
raise ValueError(
|
||||
@ -240,6 +237,10 @@ class DaskDMatrix:
|
||||
for part in parts:
|
||||
assert part.status == 'finished'
|
||||
|
||||
self.partition_order = {}
|
||||
for i, part in enumerate(parts):
|
||||
self.partition_order[part.key] = i
|
||||
|
||||
key_to_partition = {part.key: part for part in parts}
|
||||
who_has = await client.scheduler.who_has(
|
||||
keys=[part.key for part in parts])
|
||||
@ -250,6 +251,16 @@ class DaskDMatrix:
|
||||
|
||||
self.worker_map = worker_map
|
||||
|
||||
def get_worker_x_ordered(self, worker):
    '''Return the first element of every part held by ``worker``, each paired
    with that part's position in ``self.partition_order``.

    The parts are looked up in ``self.worker_map`` by ``worker.address`` and
    materialised with ``client.gather``.  The result is a list of
    ``(value[0], order)`` tuples in the same sequence as the worker's parts.
    '''
    futures_on_worker = self.worker_map[worker.address]
    gathered = get_client().gather(futures_on_worker)
    # NOTE(review): element 0 of a gathered part is presumably the feature
    # data (hence the "x" in the name) -- confirm against how parts are built.
    return [(value[0], self.partition_order[future.key])
            for future, value in zip(futures_on_worker, gathered)]
|
||||
|
||||
def get_worker_parts(self, worker):
|
||||
'''Get mapped parts of data in each worker.'''
|
||||
list_of_parts = self.worker_map[worker.address]
|
||||
@ -292,8 +303,8 @@ class DaskDMatrix:
|
||||
workers=set(self.worker_map.keys()))
|
||||
logging.warning(msg)
|
||||
d = DMatrix(numpy.empty((0, 0)),
|
||||
feature_names=self._feature_names,
|
||||
feature_types=self._feature_types)
|
||||
feature_names=self.feature_names,
|
||||
feature_types=self.feature_types)
|
||||
return d
|
||||
|
||||
data, labels, weights = self.get_worker_parts(worker)
|
||||
@ -311,9 +322,9 @@ class DaskDMatrix:
|
||||
dmatrix = DMatrix(data,
|
||||
labels,
|
||||
weight=weights,
|
||||
missing=self._missing,
|
||||
feature_names=self._feature_names,
|
||||
feature_types=self._feature_types)
|
||||
missing=self.missing,
|
||||
feature_names=self.feature_names,
|
||||
feature_types=self.feature_types)
|
||||
return dmatrix
|
||||
|
||||
def get_worker_data_shape(self, worker):
|
||||
@ -460,41 +471,65 @@ def predict(client, model, data, *args):
|
||||
worker_map = data.worker_map
|
||||
client = _xgb_get_client(client)
|
||||
|
||||
rabit_args = _get_rabit_args(worker_map, client)
|
||||
missing = data.missing
|
||||
feature_names = data.feature_names
|
||||
feature_types = data.feature_types
|
||||
|
||||
def dispatched_predict(worker_id):
|
||||
'''Perform prediction on each worker.'''
|
||||
logging.info('Predicting on %d', worker_id)
|
||||
worker = distributed_get_worker()
|
||||
local_x = data.get_worker_data(worker)
|
||||
|
||||
with RabitContext(rabit_args):
|
||||
local_predictions = booster.predict(
|
||||
data=local_x, validate_features=local_x.num_row() != 0, *args)
|
||||
return local_predictions
|
||||
|
||||
futures = client.map(dispatched_predict,
|
||||
range(len(worker_map)),
|
||||
pure=False,
|
||||
workers=list(worker_map.keys()))
|
||||
list_of_parts = data.get_worker_x_ordered(worker)
|
||||
predictions = []
|
||||
for part, order in list_of_parts:
|
||||
local_x = DMatrix(part,
|
||||
feature_names=feature_names,
|
||||
feature_types=feature_types,
|
||||
missing=missing)
|
||||
predt = booster.predict(data=local_x,
|
||||
validate_features=local_x.num_row() != 0,
|
||||
*args)
|
||||
ret = (delayed(predt), order)
|
||||
predictions.append(ret)
|
||||
return predictions
|
||||
|
||||
def dispatched_get_shape(worker_id):
|
||||
'''Get shape of data in each worker.'''
|
||||
logging.info('Trying to get data shape on %d', worker_id)
|
||||
worker = distributed_get_worker()
|
||||
rows, _ = data.get_worker_data_shape(worker)
|
||||
return rows, 1 # default is 1
|
||||
list_of_parts = data.get_worker_x_ordered(worker)
|
||||
shapes = []
|
||||
for part, order in list_of_parts:
|
||||
s = part.shape
|
||||
shapes.append((s, order))
|
||||
return shapes
|
||||
|
||||
def map_function(func):
    '''Run function for each part of the data.

    ``func`` is submitted once per key of ``worker_map``, receiving the
    key's index as its only argument and restricted to that worker through
    the ``workers`` argument.  Every call is expected to return a list of
    ``(item, order)`` pairs; the items are returned as one flat list sorted
    by ``order``.
    '''
    pending = [
        client.submit(func, wid, pure=False, workers=[address])
        for wid, address in enumerate(worker_map.keys())
    ]

    # One list of (item, order) pairs comes back per worker.
    per_worker = client.gather(pending)
    pairs = [pair for chunk in per_worker for pair in chunk]
    # list.sort is stable, like sorted(), so ties keep their gathered order.
    pairs.sort(key=lambda pair: pair[1])
    return [item for item, _ in pairs]
|
||||
|
||||
results = map_function(dispatched_predict)
|
||||
shapes = map_function(dispatched_get_shape)
|
||||
|
||||
# Constructing a dask array from list of numpy arrays
|
||||
# See https://docs.dask.org/en/latest/array-creation.html
|
||||
futures_shape = client.map(dispatched_get_shape,
|
||||
range(len(worker_map)),
|
||||
pure=False,
|
||||
workers=list(worker_map.keys()))
|
||||
shapes = client.gather(futures_shape)
|
||||
arrays = []
|
||||
for i in range(len(futures_shape)):
|
||||
arrays.append(da.from_delayed(futures[i], shape=(shapes[i][0], ),
|
||||
for i, shape in enumerate(shapes):
|
||||
arrays.append(da.from_delayed(results[i], shape=(shape[0], ),
|
||||
dtype=numpy.float32))
|
||||
predictions = da.concatenate(arrays, axis=0)
|
||||
return predictions
|
||||
|
||||
@ -74,6 +74,11 @@ def test_from_dask_array():
|
||||
# force prediction to be computed
|
||||
prediction = prediction.compute()
|
||||
|
||||
single_node_predt = result['booster'].predict(
|
||||
xgb.DMatrix(X.compute())
|
||||
)
|
||||
np.testing.assert_allclose(prediction, single_node_predt)
|
||||
|
||||
|
||||
def test_dask_regressor():
|
||||
with LocalCluster(n_workers=5) as cluster:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user