[dask] Fix prediction on DaskDMatrix with multiple meta data. (#6333)

* Unify the meta handling methods.
This commit is contained in:
Jiaming Yuan 2020-11-02 19:18:44 -05:00 committed by GitHub
parent 5a7b3592ed
commit 7756192906
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 85 additions and 58 deletions

View File

@ -18,6 +18,7 @@ import logging
from collections import defaultdict
from collections.abc import Sequence
from threading import Thread
from typing import List
import numpy
@ -300,8 +301,13 @@ class DaskDMatrix:
append_meta(margin_parts, 'base_margin')
append_meta(ll_parts, 'label_lower_bound')
append_meta(lu_parts, 'label_upper_bound')
# At this point, `parts` looks like:
# [(x0, x1, ..), (y0, y1, ..), ..] in delayed form
# delay the zipped result
parts = list(map(dask.delayed, zip(*parts)))
# At this point, the mental model should look like:
# [(x0, y0, ..), (x1, y1, ..), ..] in delayed form
parts = client.compute(parts)
await distributed.wait(parts) # async wait for parts to be computed
@ -309,6 +315,7 @@ class DaskDMatrix:
for part in parts:
assert part.status == 'finished'
# Preserving the partition order for prediction.
self.partition_order = {}
for i, part in enumerate(parts):
self.partition_order[part.key] = i
@ -339,59 +346,55 @@ class DaskDMatrix:
'is_quantile': self.is_quantile}
def _get_worker_parts_ordered(has_base_margin, worker_map, partition_order,
worker):
list_of_parts = worker_map[worker.address]
client = distributed.get_client()
def _get_worker_parts_ordered(meta_names, worker_map, partition_order, worker):
list_of_parts: List[tuple] = worker_map[worker.address]
# List of partitions like: [(x3, y3, w3, m3, ..), ..], order is not preserved.
assert isinstance(list_of_parts, list)
with distributed.worker_client() as client:
list_of_parts_value = client.gather(list_of_parts)
result = []
for i, part in enumerate(list_of_parts):
data = list_of_parts_value[i][0]
if has_base_margin:
base_margin = list_of_parts_value[i][1]
else:
base_margin = None
result.append((data, base_margin, partition_order[part.key]))
return result
def _get_worker_parts(worker_map, meta_names, worker):
'''Get mapped parts of data in each worker from DaskDMatrix.'''
list_of_parts = worker_map[worker.address]
assert list_of_parts, 'data in ' + worker.address + ' was moved.'
assert isinstance(list_of_parts, list)
# `_get_worker_parts` is launched inside worker. In dask side
# this should be equal to `worker._get_client`.
client = distributed.get_client()
list_of_parts = client.gather(list_of_parts)
data = None
labels = None
weights = None
base_margin = None
label_lower_bound = None
label_upper_bound = None
# Iterate through all possible meta info, brings small overhead as in xgboost
# there are constant number of meta info available.
for j, blob in enumerate(list_of_parts_value[i][1:]):
if meta_names[j] == 'labels':
labels = blob
elif meta_names[j] == 'weights':
weights = blob
elif meta_names[j] == 'base_margin':
base_margin = blob
elif meta_names[j] == 'label_lower_bound':
label_lower_bound = blob
elif meta_names[j] == 'label_upper_bound':
label_upper_bound = blob
else:
raise ValueError('Unknown metainfo:', meta_names[j])
local_data = list(zip(*list_of_parts))
data = local_data[0]
if partition_order:
result.append((data, labels, weights, base_margin, label_lower_bound,
label_upper_bound, partition_order[part.key]))
else:
result.append((data, labels, weights, base_margin, label_lower_bound,
label_upper_bound))
return result
for i, part in enumerate(local_data[1:]):
if meta_names[i] == 'labels':
labels = part
if meta_names[i] == 'weights':
weights = part
if meta_names[i] == 'base_margin':
base_margin = part
if meta_names[i] == 'label_lower_bound':
label_lower_bound = part
if meta_names[i] == 'label_upper_bound':
label_upper_bound = part
return (data, labels, weights, base_margin, label_lower_bound,
label_upper_bound)
def _unzip(list_of_parts):
return list(zip(*list_of_parts))
def _get_worker_parts(worker_map, meta_names, worker):
partitions = _get_worker_parts_ordered(meta_names, worker_map, None, worker)
partitions = _unzip(partitions)
return partitions
class DaskPartitionIter(DataIter): # pylint: disable=R0902
@ -585,9 +588,9 @@ def _create_dmatrix(feature_names, feature_types, meta_names, missing,
return d
def concat_or_none(data):
if data is not None:
if all([part is None for part in data]):
return None
return concat(data)
return data
(data, labels, weights, base_margin,
label_lower_bound, label_upper_bound) = _get_worker_parts(
@ -795,7 +798,7 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
feature_names = data.feature_names
feature_types = data.feature_types
missing = data.missing
has_margin = "base_margin" in data.meta_names
meta_names = data.meta_names
def dispatched_predict(worker_id):
'''Perform prediction on each worker.'''
@ -803,10 +806,11 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
worker = distributed.get_worker()
list_of_parts = _get_worker_parts_ordered(
has_margin, worker_map, partition_order, worker)
meta_names, worker_map, partition_order, worker)
predictions = []
booster.set_param({'nthread': worker.nthreads})
for data, base_margin, order in list_of_parts:
for parts in list_of_parts:
(data, _, _, base_margin, _, _, order) = parts
local_part = DMatrix(
data,
base_margin=base_margin,
@ -829,12 +833,15 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
LOGGER.info('Get shape on %d', worker_id)
worker = distributed.get_worker()
list_of_parts = _get_worker_parts_ordered(
False,
meta_names,
worker_map,
partition_order,
worker
)
shapes = [(part.shape, order) for part, _, order in list_of_parts]
shapes = []
for parts in list_of_parts:
(data, _, _, _, _, _, order) = parts
shapes.append((data.shape, order))
return shapes
async def map_function(func):
@ -974,8 +981,7 @@ def inplace_predict(client, model, data,
missing=missing)
async def _evaluation_matrices(client, validation_set,
sample_weight, missing):
async def _evaluation_matrices(client, validation_set, sample_weight, missing):
'''
Parameters
----------
@ -998,8 +1004,7 @@ async def _evaluation_matrices(client, validation_set,
if validation_set is not None:
assert isinstance(validation_set, list)
for i, e in enumerate(validation_set):
w = (sample_weight[i]
if sample_weight is not None else None)
w = (sample_weight[i] if sample_weight is not None else None)
dmat = await DaskDMatrix(client=client, data=e[0], label=e[1],
weight=w, missing=missing)
evals.append((dmat, 'validation_{}'.format(i)))

View File

@ -566,6 +566,28 @@ def test_predict():
assert shap.shape[1] == kCols + 1
def test_predict_with_meta(client):
X, y, w = generate_array(with_weights=True)
partition_size = 20
margin = da.random.random(kRows, partition_size) + 1e4
dtrain = DaskDMatrix(client, X, y, weight=w, base_margin=margin)
booster = xgb.dask.train(
client, {}, dtrain, num_boost_round=4)['booster']
prediction = xgb.dask.predict(client, model=booster, data=dtrain)
assert prediction.ndim == 1
assert prediction.shape[0] == kRows
prediction = client.compute(prediction).result()
assert np.all(prediction > 1e3)
m = xgb.DMatrix(X.compute())
m.set_info(label=y.compute(), weight=w.compute(), base_margin=margin.compute())
single = booster.predict(m) # Make sure the ordering is correct.
assert np.all(prediction == single)
def run_aft_survival(client, dmatrix_t):
# survival doesn't handle empty dataset well.
df = dd.read_csv(os.path.join(tm.PROJECT_ROOT, 'demo', 'data',