From 7756192906108d7dd5dfb1cd24e10a65f0cff51f Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 2 Nov 2020 19:18:44 -0500 Subject: [PATCH] [dask] Fix prediction on `DaskDMatrix` with multiple meta data. (#6333) * Unify the meta handling methods. --- python-package/xgboost/dask.py | 121 +++++++++++++++++---------------- tests/python/test_with_dask.py | 22 ++++++ 2 files changed, 85 insertions(+), 58 deletions(-) diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 7ae24f9d3..8dd336263 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -18,6 +18,7 @@ import logging from collections import defaultdict from collections.abc import Sequence from threading import Thread +from typing import List import numpy @@ -300,8 +301,13 @@ class DaskDMatrix: append_meta(margin_parts, 'base_margin') append_meta(ll_parts, 'label_lower_bound') append_meta(lu_parts, 'label_upper_bound') + # At this point, `parts` looks like: + # [(x0, x1, ..), (y0, y1, ..), ..] in delayed form + # delay the zipped result parts = list(map(dask.delayed, zip(*parts))) + # At this point, the mental model should look like: + # [(x0, y0, ..), (x1, y1, ..), ..] in delayed form parts = client.compute(parts) await distributed.wait(parts) # async wait for parts to be computed @@ -309,6 +315,7 @@ class DaskDMatrix: for part in parts: assert part.status == 'finished' + # Preserving the partition order for prediction. self.partition_order = {} for i, part in enumerate(parts): self.partition_order[part.key] = i @@ -339,59 +346,55 @@ class DaskDMatrix: 'is_quantile': self.is_quantile} -def _get_worker_parts_ordered(has_base_margin, worker_map, partition_order, - worker): - list_of_parts = worker_map[worker.address] - client = distributed.get_client() - list_of_parts_value = client.gather(list_of_parts) +def _get_worker_parts_ordered(meta_names, worker_map, partition_order, worker): + list_of_parts: List[tuple] = worker_map[worker.address] + # List of partitions like: [(x3, y3, w3, m3, ..), ..], order is not preserved. + assert isinstance(list_of_parts, list) + with distributed.worker_client() as client: + list_of_parts_value = client.gather(list_of_parts) - result = [] + result = [] - for i, part in enumerate(list_of_parts): - data = list_of_parts_value[i][0] - if has_base_margin: - base_margin = list_of_parts_value[i][1] - else: + for i, part in enumerate(list_of_parts): + data = list_of_parts_value[i][0] + labels = None + weights = None base_margin = None - result.append((data, base_margin, partition_order[part.key])) + label_lower_bound = None + label_upper_bound = None + # Iterate through all possible meta info, brings small overhead as in xgboost + # there are constant number of meta info available. + for j, blob in enumerate(list_of_parts_value[i][1:]): + if meta_names[j] == 'labels': + labels = blob + elif meta_names[j] == 'weights': + weights = blob + elif meta_names[j] == 'base_margin': + base_margin = blob + elif meta_names[j] == 'label_lower_bound': + label_lower_bound = blob + elif meta_names[j] == 'label_upper_bound': + label_upper_bound = blob + else: + raise ValueError('Unknown metainfo:', meta_names[j]) - return result + if partition_order: + result.append((data, labels, weights, base_margin, label_lower_bound, + label_upper_bound, partition_order[part.key])) + else: + result.append((data, labels, weights, base_margin, label_lower_bound, + label_upper_bound)) + return result + + +def _unzip(list_of_parts): + return list(zip(*list_of_parts)) def _get_worker_parts(worker_map, meta_names, worker): - '''Get mapped parts of data in each worker from DaskDMatrix.''' - list_of_parts = worker_map[worker.address] - assert list_of_parts, 'data in ' + worker.address + ' was moved.' - assert isinstance(list_of_parts, list) - - # `_get_worker_parts` is launched inside worker. In dask side - # this should be equal to `worker._get_client`. - client = distributed.get_client() - list_of_parts = client.gather(list_of_parts) - data = None - labels = None - weights = None - base_margin = None - label_lower_bound = None - label_upper_bound = None - - local_data = list(zip(*list_of_parts)) - data = local_data[0] - - for i, part in enumerate(local_data[1:]): - if meta_names[i] == 'labels': - labels = part - if meta_names[i] == 'weights': - weights = part - if meta_names[i] == 'base_margin': - base_margin = part - if meta_names[i] == 'label_lower_bound': - label_lower_bound = part - if meta_names[i] == 'label_upper_bound': - label_upper_bound = part - - return (data, labels, weights, base_margin, label_lower_bound, - label_upper_bound) + partitions = _get_worker_parts_ordered(meta_names, worker_map, None, worker) + partitions = _unzip(partitions) + return partitions class DaskPartitionIter(DataIter): # pylint: disable=R0902 @@ -585,9 +588,9 @@ def _create_dmatrix(feature_names, feature_types, meta_names, missing, return d def concat_or_none(data): - if data is not None: - return concat(data) - return data + if all([part is None for part in data]): + return None + return concat(data) (data, labels, weights, base_margin, label_lower_bound, label_upper_bound) = _get_worker_parts( @@ -795,7 +798,7 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs): feature_names = data.feature_names feature_types = data.feature_types missing = data.missing - has_margin = "base_margin" in data.meta_names + meta_names = data.meta_names def dispatched_predict(worker_id): '''Perform prediction on each worker.''' @@ -803,10 +806,11 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs): worker = distributed.get_worker() list_of_parts = _get_worker_parts_ordered( - has_margin, worker_map, partition_order, worker) + meta_names, worker_map, partition_order, worker) predictions = [] booster.set_param({'nthread': worker.nthreads}) - for data, base_margin, order in list_of_parts: + for parts in list_of_parts: + (data, _, _, base_margin, _, _, order) = parts local_part = DMatrix( data, base_margin=base_margin, @@ -829,12 +833,15 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs): LOGGER.info('Get shape on %d', worker_id) worker = distributed.get_worker() list_of_parts = _get_worker_parts_ordered( - False, + meta_names, worker_map, partition_order, worker ) - shapes = [(part.shape, order) for part, _, order in list_of_parts] + shapes = [] + for parts in list_of_parts: + (data, _, _, _, _, _, order) = parts + shapes.append((data.shape, order)) return shapes async def map_function(func): @@ -974,8 +981,7 @@ def inplace_predict(client, model, data, missing=missing) -async def _evaluation_matrices(client, validation_set, - sample_weight, missing): +async def _evaluation_matrices(client, validation_set, sample_weight, missing): ''' Parameters ---------- @@ -998,8 +1004,7 @@ async def _evaluation_matrices(client, validation_set, if validation_set is not None: assert isinstance(validation_set, list) for i, e in enumerate(validation_set): - w = (sample_weight[i] - if sample_weight is not None else None) + w = (sample_weight[i] if sample_weight is not None else None) dmat = await DaskDMatrix(client=client, data=e[0], label=e[1], weight=w, missing=missing) evals.append((dmat, 'validation_{}'.format(i))) diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index 3bcc5c865..9eff943b5 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -566,6 +566,28 @@ def test_predict(): assert shap.shape[1] == kCols + 1 +def test_predict_with_meta(client): + X, y, w = generate_array(with_weights=True) + partition_size = 20 + margin = da.random.random(kRows, partition_size) + 1e4 + + dtrain = DaskDMatrix(client, X, y, weight=w, base_margin=margin) + booster = xgb.dask.train( + client, {}, dtrain, num_boost_round=4)['booster'] + + prediction = xgb.dask.predict(client, model=booster, data=dtrain) + assert prediction.ndim == 1 + assert prediction.shape[0] == kRows + + prediction = client.compute(prediction).result() + assert np.all(prediction > 1e3) + + m = xgb.DMatrix(X.compute()) + m.set_info(label=y.compute(), weight=w.compute(), base_margin=margin.compute()) + single = booster.predict(m) # Make sure the ordering is correct. + assert np.all(prediction == single) + + def run_aft_survival(client, dmatrix_t): # survival doesn't handle empty dataset well. df = dd.read_csv(os.path.join(tm.PROJECT_ROOT, 'demo', 'data',