[dask] Order the prediction result. (#5416)
parent 668e432e2d
commit 21b671aa06
@@ -142,9 +142,6 @@ class DaskDMatrix:
 
     '''
 
-    _feature_names = None  # for previous version's pickle
-    _feature_types = None
-
     def __init__(self,
                  client,
                  data,
@@ -156,9 +153,9 @@ class DaskDMatrix:
         _assert_dask_support()
         _assert_client(client)
 
-        self._feature_names = feature_names
-        self._feature_types = feature_types
-        self._missing = missing
+        self.feature_names = feature_names
+        self.feature_types = feature_types
+        self.missing = missing
 
         if len(data.shape) != 2:
             raise ValueError(
@@ -240,6 +237,10 @@ class DaskDMatrix:
         for part in parts:
             assert part.status == 'finished'
 
+        self.partition_order = {}
+        for i, part in enumerate(parts):
+            self.partition_order[part.key] = i
+
         key_to_partition = {part.key: part for part in parts}
         who_has = await client.scheduler.who_has(
             keys=[part.key for part in parts])
@@ -250,6 +251,16 @@ class DaskDMatrix:
 
         self.worker_map = worker_map
 
+    def get_worker_x_ordered(self, worker):
+        list_of_parts = self.worker_map[worker.address]
+        client = get_client()
+        list_of_parts_value = client.gather(list_of_parts)
+        result = []
+        for i, part in enumerate(list_of_parts):
+            result.append((list_of_parts_value[i][0],
+                           self.partition_order[part.key]))
+        return result
+
     def get_worker_parts(self, worker):
         '''Get mapped parts of data in each worker.'''
         list_of_parts = self.worker_map[worker.address]
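
The new `get_worker_x_ordered` pairs each partition held by a worker with that partition's global index from `partition_order`. A toy illustration of the bookkeeping (made-up keys and worker addresses, not xgboost code):

```python
# partition_order maps a partition key to its global index;
# worker_map records which partitions live on which worker.
partition_order = {'part-a': 0, 'part-b': 1, 'part-c': 2}
worker_map = {'tcp://w1': ['part-b'], 'tcp://w2': ['part-a', 'part-c']}

for address, keys in worker_map.items():
    # Each worker sees only its own parts, tagged with global indices.
    print(address, [(key, partition_order[key]) for key in keys])
# tcp://w1 [('part-b', 1)]
# tcp://w2 [('part-a', 0), ('part-c', 2)]
```
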
@@ -292,8 +303,8 @@ class DaskDMatrix:
                 workers=set(self.worker_map.keys()))
             logging.warning(msg)
             d = DMatrix(numpy.empty((0, 0)),
-                        feature_names=self._feature_names,
-                        feature_types=self._feature_types)
+                        feature_names=self.feature_names,
+                        feature_types=self.feature_types)
             return d
 
         data, labels, weights = self.get_worker_parts(worker)
@@ -311,9 +322,9 @@ class DaskDMatrix:
         dmatrix = DMatrix(data,
                           labels,
                           weight=weights,
-                          missing=self._missing,
-                          feature_names=self._feature_names,
-                          feature_types=self._feature_types)
+                          missing=self.missing,
+                          feature_names=self.feature_names,
+                          feature_types=self.feature_types)
         return dmatrix
 
     def get_worker_data_shape(self, worker):
@@ -460,41 +471,65 @@ def predict(client, model, data, *args):
     worker_map = data.worker_map
     client = _xgb_get_client(client)
 
-    rabit_args = _get_rabit_args(worker_map, client)
+    missing = data.missing
+    feature_names = data.feature_names
+    feature_types = data.feature_types
 
     def dispatched_predict(worker_id):
         '''Perform prediction on each worker.'''
         logging.info('Predicting on %d', worker_id)
         worker = distributed_get_worker()
-        local_x = data.get_worker_data(worker)
-
-        with RabitContext(rabit_args):
-            local_predictions = booster.predict(
-                data=local_x, validate_features=local_x.num_row() != 0, *args)
-        return local_predictions
-
-    futures = client.map(dispatched_predict,
-                         range(len(worker_map)),
-                         pure=False,
-                         workers=list(worker_map.keys()))
+        list_of_parts = data.get_worker_x_ordered(worker)
+        predictions = []
+        for part, order in list_of_parts:
+            local_x = DMatrix(part,
+                              feature_names=feature_names,
+                              feature_types=feature_types,
+                              missing=missing)
+            predt = booster.predict(data=local_x,
+                                    validate_features=local_x.num_row() != 0,
+                                    *args)
+            ret = (delayed(predt), order)
+            predictions.append(ret)
+        return predictions
 
     def dispatched_get_shape(worker_id):
         '''Get shape of data in each worker.'''
         logging.info('Trying to get data shape on %d', worker_id)
         worker = distributed_get_worker()
-        rows, _ = data.get_worker_data_shape(worker)
-        return rows, 1  # default is 1
+        list_of_parts = data.get_worker_x_ordered(worker)
+        shapes = []
+        for part, order in list_of_parts:
+            s = part.shape
+            shapes.append((s, order))
+        return shapes
+
+    def map_function(func):
+        '''Run function for each part of the data.'''
+        futures = []
+        for wid in range(len(worker_map)):
+            list_of_workers = [list(worker_map.keys())[wid]]
+            f = client.submit(func, wid,
+                              pure=False,
+                              workers=list_of_workers)
+            futures.append(f)
+
+        # Get delayed objects
+        results = client.gather(futures)
+        results = [t for l in results for t in l]  # flatten into 1 dim list
+        # sort by order, l[0] is the delayed object, l[1] is its order
+        results = sorted(results, key=lambda l: l[1])
+        results = [predt for predt, order in results]  # remove order
+        return results
+
+    results = map_function(dispatched_predict)
+    shapes = map_function(dispatched_get_shape)
 
     # Constructing a dask array from list of numpy arrays
     # See https://docs.dask.org/en/latest/array-creation.html
-    futures_shape = client.map(dispatched_get_shape,
-                               range(len(worker_map)),
-                               pure=False,
-                               workers=list(worker_map.keys()))
-    shapes = client.gather(futures_shape)
     arrays = []
-    for i in range(len(futures_shape)):
-        arrays.append(da.from_delayed(futures[i], shape=(shapes[i][0], ),
-                                      dtype=numpy.float32))
+    for i, shape in enumerate(shapes):
+        arrays.append(da.from_delayed(results[i], shape=(shape[0], ),
+                                      dtype=numpy.float32))
     predictions = da.concatenate(arrays, axis=0)
     return predictions
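
The core of the fix, readable from the hunk above: each worker now returns (delayed prediction, partition order) pairs instead of one concatenated array, and `map_function` flattens the per-worker lists and sorts by the order tag before the dask array is assembled. A minimal standalone sketch of that flatten-and-sort step, with placeholder strings standing in for the delayed objects:

```python
# Placeholder results from two workers; each inner list holds that
# worker's (result, partition_order) pairs, in arbitrary order.
per_worker = [[('pred-b', 1), ('pred-d', 3)],
              [('pred-a', 0), ('pred-c', 2)]]

flat = [t for l in per_worker for t in l]   # flatten into 1-dim list
flat = sorted(flat, key=lambda t: t[1])     # sort by partition order
results = [r for r, _ in flat]              # drop the order tags
assert results == ['pred-a', 'pred-b', 'pred-c', 'pred-d']
```
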
@@ -74,6 +74,11 @@ def test_from_dask_array():
             # force prediction to be computed
             prediction = prediction.compute()
 
+            single_node_predt = result['booster'].predict(
+                xgb.DMatrix(X.compute())
+            )
+            np.testing.assert_allclose(prediction, single_node_predt)
+
 
 def test_dask_regressor():
     with LocalCluster(n_workers=5) as cluster:
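
For context, a hedged end-to-end sketch of what the new test asserts: distributed predictions now line up row-for-row with single-node predictions on the same data. The calls follow the xgboost 1.x dask API as shown in the diff; the random arrays and cluster size are stand-ins:

```python
import numpy as np
import dask.array as da
import xgboost as xgb
from dask.distributed import Client, LocalCluster

with LocalCluster(n_workers=2) as cluster:
    with Client(cluster) as client:
        X = da.random.random((100, 10), chunks=(20, 10))
        y = da.random.random(100, chunks=(20,))
        dtrain = xgb.dask.DaskDMatrix(client, X, y)
        result = xgb.dask.train(client, {}, dtrain, num_boost_round=2)
        # Distributed prediction, materialized as one numpy array.
        prediction = xgb.dask.predict(client, result, dtrain).compute()
        # Single-node prediction preserves the input row order by
        # construction; after this commit the two must match.
        single_node = result['booster'].predict(xgb.DMatrix(X.compute()))
        np.testing.assert_allclose(prediction, single_node)
```
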