[dask] Fix prediction on DaskDMatrix with multiple meta data. (#6333)

* Unify the meta handling methods.
Jiaming Yuan 2020-11-02 19:18:44 -05:00 committed by GitHub
parent 5a7b3592ed
commit 7756192906
2 changed files with 85 additions and 58 deletions

python-package/xgboost/dask.py

@@ -18,6 +18,7 @@ import logging
 from collections import defaultdict
 from collections.abc import Sequence
 from threading import Thread
+from typing import List

 import numpy
@@ -300,8 +301,13 @@ class DaskDMatrix:
         append_meta(margin_parts, 'base_margin')
         append_meta(ll_parts, 'label_lower_bound')
         append_meta(lu_parts, 'label_upper_bound')
+        # At this point, `parts` looks like:
+        # [(x0, x1, ..), (y0, y1, ..), ..] in delayed form

+        # delay the zipped result
         parts = list(map(dask.delayed, zip(*parts)))
+        # At this point, the mental model should look like:
+        # [(x0, y0, ..), (x1, y1, ..), ..] in delayed form

         parts = client.compute(parts)
         await distributed.wait(parts)  # async wait for parts to be computed
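For readers new to the delayed transposition described by the new comments, here is a minimal standalone sketch of what zip(*parts) plus dask.delayed achieves; it uses plain dask with string stand-ins (x0, y0, ...) for real partitions and is not part of this commit:

import dask
from distributed import Client

x_parts = ['x0', 'x1', 'x2']  # data column, one entry per partition
y_parts = ['y0', 'y1', 'y2']  # label column, one entry per partition
parts = [x_parts, y_parts]    # [(x0, x1, ..), (y0, y1, ..)]

# zip(*parts) transposes columns into per-partition tuples, and
# dask.delayed wraps each tuple so it materializes on a worker.
parts = list(map(dask.delayed, zip(*parts)))
# Mental model now: [(x0, y0), (x1, y1), (x2, y2)] in delayed form.

if __name__ == '__main__':
    client = Client()
    futures = client.compute(parts)
    print(client.gather(futures))  # [('x0', 'y0'), ('x1', 'y1'), ('x2', 'y2')]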
@@ -309,6 +315,7 @@ class DaskDMatrix:
         for part in parts:
             assert part.status == 'finished'

+        # Preserving the partition order for prediction.
         self.partition_order = {}
         for i, part in enumerate(parts):
             self.partition_order[part.key] = i
@@ -339,59 +346,55 @@ class DaskDMatrix:
                 'is_quantile': self.is_quantile}


-def _get_worker_parts_ordered(has_base_margin, worker_map, partition_order,
-                              worker):
-    list_of_parts = worker_map[worker.address]
-    client = distributed.get_client()
-    list_of_parts_value = client.gather(list_of_parts)
-
-    result = []
-    for i, part in enumerate(list_of_parts):
-        data = list_of_parts_value[i][0]
-        if has_base_margin:
-            base_margin = list_of_parts_value[i][1]
-        else:
-            base_margin = None
-        result.append((data, base_margin, partition_order[part.key]))
-
-    return result
+def _get_worker_parts_ordered(meta_names, worker_map, partition_order, worker):
+    list_of_parts: List[tuple] = worker_map[worker.address]
+    # List of partitions like: [(x3, y3, w3, m3, ..), ..], order is not preserved.
+    assert isinstance(list_of_parts, list)
+
+    with distributed.worker_client() as client:
+        list_of_parts_value = client.gather(list_of_parts)
+        result = []
+        for i, part in enumerate(list_of_parts):
+            data = list_of_parts_value[i][0]
+            labels = None
+            weights = None
+            base_margin = None
+            label_lower_bound = None
+            label_upper_bound = None
+            # Iterate through all possible meta info, brings small overhead as in
+            # xgboost there are constant number of meta info available.
+            for j, blob in enumerate(list_of_parts_value[i][1:]):
+                if meta_names[j] == 'labels':
+                    labels = blob
+                elif meta_names[j] == 'weights':
+                    weights = blob
+                elif meta_names[j] == 'base_margin':
+                    base_margin = blob
+                elif meta_names[j] == 'label_lower_bound':
+                    label_lower_bound = blob
+                elif meta_names[j] == 'label_upper_bound':
+                    label_upper_bound = blob
+                else:
+                    raise ValueError('Unknown metainfo:', meta_names[j])
+            if partition_order:
+                result.append((data, labels, weights, base_margin, label_lower_bound,
+                               label_upper_bound, partition_order[part.key]))
+            else:
+                result.append((data, labels, weights, base_margin, label_lower_bound,
+                               label_upper_bound))
+        return result
+
+
+def _unzip(list_of_parts):
+    return list(zip(*list_of_parts))


 def _get_worker_parts(worker_map, meta_names, worker):
     '''Get mapped parts of data in each worker from DaskDMatrix.'''
-    list_of_parts = worker_map[worker.address]
-    assert list_of_parts, 'data in ' + worker.address + ' was moved.'
-    assert isinstance(list_of_parts, list)
-
-    # `_get_worker_parts` is launched inside worker.  In dask side
-    # this should be equal to `worker._get_client`.
-    client = distributed.get_client()
-    list_of_parts = client.gather(list_of_parts)
-    data = None
-    labels = None
-    weights = None
-    base_margin = None
-    label_lower_bound = None
-    label_upper_bound = None
-    local_data = list(zip(*list_of_parts))
-    data = local_data[0]
-
-    for i, part in enumerate(local_data[1:]):
-        if meta_names[i] == 'labels':
-            labels = part
-        if meta_names[i] == 'weights':
-            weights = part
-        if meta_names[i] == 'base_margin':
-            base_margin = part
-        if meta_names[i] == 'label_lower_bound':
-            label_lower_bound = part
-        if meta_names[i] == 'label_upper_bound':
-            label_upper_bound = part
-
-    return (data, labels, weights, base_margin, label_lower_bound,
-            label_upper_bound)
+    partitions = _get_worker_parts_ordered(meta_names, worker_map, None, worker)
+    partitions = _unzip(partitions)
+    return partitions


 class DaskPartitionIter(DataIter):  # pylint: disable=R0902
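The tuple layout these two helpers assume, shown without a cluster (hypothetical values, not part of this commit): each partition is (data, *metas) with metas in meta_names order, and _unzip is a plain transpose turning per-partition rows back into per-field columns:

# Hypothetical partition from a DaskDMatrix built with labels and weights:
# data first, then metas in the same order as meta_names.
meta_names = ['labels', 'weights']
partition = ('X0', 'y0', 'w0')

data = partition[0]
labels = weights = None
for name, blob in zip(meta_names, partition[1:]):
    if name == 'labels':
        labels = blob
    elif name == 'weights':
        weights = blob
assert (data, labels, weights) == ('X0', 'y0', 'w0')

# _unzip: rows of per-partition tuples become columns of per-field tuples,
# ready for per-field concatenation.
rows = [('X0', 'y0', 'w0'), ('X1', 'y1', 'w1')]
columns = list(zip(*rows))
assert columns == [('X0', 'X1'), ('y0', 'y1'), ('w0', 'w1')]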
@@ -585,9 +588,9 @@ def _create_dmatrix(feature_names, feature_types, meta_names, missing,
         return d

     def concat_or_none(data):
-        if data is not None:
-            return concat(data)
-        return data
+        if all([part is None for part in data]):
+            return None
+        return concat(data)

     (data, labels, weights, base_margin,
      label_lower_bound, label_upper_bound) = _get_worker_parts(
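The behavioral fix in concat_or_none, isolated (numpy.concatenate standing in for xgboost's concat dispatcher; hypothetical inputs): a meta field that was never supplied now arrives as a tuple of Nones rather than None itself, which the old is-not-None test wrongly forwarded to concat:

import numpy as np

def concat_or_none(data):
    # A field absent from every partition is a tuple of Nones and must
    # collapse to None instead of being concatenated.
    if all([part is None for part in data]):
        return None
    return np.concatenate(data)

labels = (np.array([0, 1]), np.array([1, 0]))
margins = (None, None)  # base_margin was never set on this matrix
assert concat_or_none(labels).tolist() == [0, 1, 1, 0]
assert concat_or_none(margins) is None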
@@ -795,7 +798,7 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
     feature_names = data.feature_names
     feature_types = data.feature_types
     missing = data.missing
-    has_margin = "base_margin" in data.meta_names
+    meta_names = data.meta_names

     def dispatched_predict(worker_id):
         '''Perform prediction on each worker.'''
@@ -803,10 +806,11 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
         worker = distributed.get_worker()
         list_of_parts = _get_worker_parts_ordered(
-            has_margin, worker_map, partition_order, worker)
+            meta_names, worker_map, partition_order, worker)
         predictions = []
         booster.set_param({'nthread': worker.nthreads})
-        for data, base_margin, order in list_of_parts:
+        for parts in list_of_parts:
+            (data, _, _, base_margin, _, _, order) = parts
             local_part = DMatrix(
                 data,
                 base_margin=base_margin,
@@ -829,12 +833,15 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
         LOGGER.info('Get shape on %d', worker_id)
         worker = distributed.get_worker()
         list_of_parts = _get_worker_parts_ordered(
-            False,
+            meta_names,
             worker_map,
             partition_order,
             worker
         )
-        shapes = [(part.shape, order) for part, _, order in list_of_parts]
+        shapes = []
+        for parts in list_of_parts:
+            (data, _, _, _, _, _, order) = parts
+            shapes.append((data.shape, order))
         return shapes

     async def map_function(func):
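Why the trailing order index matters: per-partition results return from workers in arbitrary order, so they must be sorted back into DaskDMatrix partition order before being stitched together. A hypothetical reassembly with numpy stand-ins, not code from this diff:

import numpy as np

# Per-partition predictions as they might arrive from workers, each
# tagged with its original partition index.
results = [(np.array([0.2, 0.3]), 2),
           (np.array([0.9]), 0),
           (np.array([0.5, 0.1]), 1)]

# Restore the original partition order, then concatenate.
results.sort(key=lambda pair: pair[1])
prediction = np.concatenate([part for part, _ in results])
assert prediction.tolist() == [0.9, 0.5, 0.1, 0.2, 0.3]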
@@ -974,8 +981,7 @@ def inplace_predict(client, model, data,
                                missing=missing)


-async def _evaluation_matrices(client, validation_set,
-                               sample_weight, missing):
+async def _evaluation_matrices(client, validation_set, sample_weight, missing):
     '''
     Parameters
     ----------
@@ -998,8 +1004,7 @@ async def _evaluation_matrices(client, validation_set,
     if validation_set is not None:
         assert isinstance(validation_set, list)
         for i, e in enumerate(validation_set):
-            w = (sample_weight[i]
-                 if sample_weight is not None else None)
+            w = (sample_weight[i] if sample_weight is not None else None)
             dmat = await DaskDMatrix(client=client, data=e[0], label=e[1],
                                      weight=w, missing=missing)
             evals.append((dmat, 'validation_{}'.format(i)))

tests/python/test_with_dask.py

@@ -566,6 +566,28 @@ def test_predict():
     assert shap.shape[1] == kCols + 1


+def test_predict_with_meta(client):
+    X, y, w = generate_array(with_weights=True)
+    partition_size = 20
+    margin = da.random.random(kRows, partition_size) + 1e4
+
+    dtrain = DaskDMatrix(client, X, y, weight=w, base_margin=margin)
+    booster = xgb.dask.train(
+        client, {}, dtrain, num_boost_round=4)['booster']
+
+    prediction = xgb.dask.predict(client, model=booster, data=dtrain)
+    assert prediction.ndim == 1
+    assert prediction.shape[0] == kRows
+
+    prediction = client.compute(prediction).result()
+    assert np.all(prediction > 1e3)
+
+    m = xgb.DMatrix(X.compute())
+    m.set_info(label=y.compute(), weight=w.compute(), base_margin=margin.compute())
+    single = booster.predict(m)  # Make sure the ordering is correct.
+    assert np.all(prediction == single)
+
+
 def run_aft_survival(client, dmatrix_t):
     # survival doesn't handle empty dataset well.
     df = dd.read_csv(os.path.join(tm.PROJECT_ROOT, 'demo', 'data',