[dask] Fix prediction on DaskDMatrix with multiple meta data. (#6333)

* Unify the meta handling methods.

parent 5a7b3592ed · commit 7756192906
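Context for the fix: the old `_get_worker_parts_ordered` took a boolean `has_base_margin` and read the base margin from a fixed position in each partition tuple, which breaks as soon as more than one meta field (labels, weights, margin, bounds) is attached. A toy sketch of the failure mode and the name-based lookup that replaces it (hypothetical values, no Dask involved):

# Per-partition payload when labels, weights and base_margin are all set:
# (data, labels, weights, base_margin).
part = ('x0', 'y0', 'w0', 'm0')

# Old lookup: the blob at index 1 was assumed to be the base margin, which
# only holds when base_margin is the sole meta field attached.
wrong_margin = part[1]                      # 'y0' -- actually the labels

# New lookup: resolve each blob by its position in meta_names.
meta_names = ['labels', 'weights', 'base_margin']
fields = dict(zip(meta_names, part[1:]))
assert fields['base_margin'] == 'm0'        # the real margin
assert wrong_margin != fields['base_margin']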
python-package/xgboost/dask.py

@@ -18,6 +18,7 @@ import logging
 from collections import defaultdict
 from collections.abc import Sequence
 from threading import Thread
+from typing import List

 import numpy

@@ -300,8 +301,13 @@ class DaskDMatrix:
         append_meta(margin_parts, 'base_margin')
         append_meta(ll_parts, 'label_lower_bound')
         append_meta(lu_parts, 'label_upper_bound')
+        # At this point, `parts` looks like:
+        # [(x0, x1, ..), (y0, y1, ..), ..] in delayed form

+        # delay the zipped result
         parts = list(map(dask.delayed, zip(*parts)))
+        # At this point, the mental model should look like:
+        # [(x0, y0, ..), (x1, y1, ..), ..] in delayed form
         parts = client.compute(parts)
         await distributed.wait(parts)  # async wait for parts to be computed
@@ -309,6 +315,7 @@ class DaskDMatrix:
         for part in parts:
             assert part.status == 'finished'

+        # Preserving the partition order for prediction.
         self.partition_order = {}
         for i, part in enumerate(parts):
             self.partition_order[part.key] = i
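The bookkeeping added above is a plain mapping from each computed future's key to its partition index; a toy equivalent (hypothetical keys, `Part` standing in for a distributed.Future):

class Part:
    def __init__(self, key):
        self.key = key

parts = [Part('part-a'), Part('part-b')]
partition_order = {}
for i, part in enumerate(parts):
    partition_order[part.key] = i
assert partition_order == {'part-a': 0, 'part-b': 1}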
@@ -339,59 +346,55 @@ class DaskDMatrix:
                 'is_quantile': self.is_quantile}


-def _get_worker_parts_ordered(has_base_margin, worker_map, partition_order,
-                              worker):
-    list_of_parts = worker_map[worker.address]
-    client = distributed.get_client()
-    list_of_parts_value = client.gather(list_of_parts)
-    result = []
-    for i, part in enumerate(list_of_parts):
-        data = list_of_parts_value[i][0]
-        if has_base_margin:
-            base_margin = list_of_parts_value[i][1]
-        else:
-            base_margin = None
-        result.append((data, base_margin, partition_order[part.key]))
-    return result
-
-
-def _get_worker_parts(worker_map, meta_names, worker):
-    '''Get mapped parts of data in each worker from DaskDMatrix.'''
-    list_of_parts = worker_map[worker.address]
-    assert list_of_parts, 'data in ' + worker.address + ' was moved.'
-    assert isinstance(list_of_parts, list)
-
-    # `_get_worker_parts` is launched inside worker.  In dask side
-    # this should be equal to `worker._get_client`.
-    client = distributed.get_client()
-    list_of_parts = client.gather(list_of_parts)
-    data = None
-    labels = None
-    weights = None
-    base_margin = None
-    label_lower_bound = None
-    label_upper_bound = None
-    local_data = list(zip(*list_of_parts))
-    data = local_data[0]
-    for i, part in enumerate(local_data[1:]):
-        if meta_names[i] == 'labels':
-            labels = part
-        if meta_names[i] == 'weights':
-            weights = part
-        if meta_names[i] == 'base_margin':
-            base_margin = part
-        if meta_names[i] == 'label_lower_bound':
-            label_lower_bound = part
-        if meta_names[i] == 'label_upper_bound':
-            label_upper_bound = part
-
-    return (data, labels, weights, base_margin, label_lower_bound,
-            label_upper_bound)
+def _get_worker_parts_ordered(meta_names, worker_map, partition_order, worker):
+    list_of_parts: List[tuple] = worker_map[worker.address]
+    # List of partitions like: [(x3, y3, w3, m3, ..), ..], order is not preserved.
+    assert isinstance(list_of_parts, list)
+
+    with distributed.worker_client() as client:
+        list_of_parts_value = client.gather(list_of_parts)
+        result = []
+        for i, part in enumerate(list_of_parts):
+            data = list_of_parts_value[i][0]
+            labels = None
+            weights = None
+            base_margin = None
+            label_lower_bound = None
+            label_upper_bound = None
+            # Iterate through all possible meta fields; this brings only a
+            # small overhead, as xgboost has a constant number of them.
+            for j, blob in enumerate(list_of_parts_value[i][1:]):
+                if meta_names[j] == 'labels':
+                    labels = blob
+                elif meta_names[j] == 'weights':
+                    weights = blob
+                elif meta_names[j] == 'base_margin':
+                    base_margin = blob
+                elif meta_names[j] == 'label_lower_bound':
+                    label_lower_bound = blob
+                elif meta_names[j] == 'label_upper_bound':
+                    label_upper_bound = blob
+                else:
+                    raise ValueError('Unknown metainfo:', meta_names[j])
+            if partition_order:
+                result.append((data, labels, weights, base_margin, label_lower_bound,
+                               label_upper_bound, partition_order[part.key]))
+            else:
+                result.append((data, labels, weights, base_margin, label_lower_bound,
+                               label_upper_bound))
+    return result
+
+
+def _unzip(list_of_parts):
+    return list(zip(*list_of_parts))
+
+
+def _get_worker_parts(worker_map, meta_names, worker):
+    partitions = _get_worker_parts_ordered(meta_names, worker_map, None, worker)
+    partitions = _unzip(partitions)
+    return partitions


 class DaskPartitionIter(DataIter):  # pylint: disable=R0902
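For reference, each ordered part is now a fixed-layout tuple (data, labels, weights, base_margin, label_lower_bound, label_upper_bound), with a trailing partition index when partition_order is given, and `_unzip` is a plain transpose. A toy sketch with hypothetical two-partition values:

parts = [('x0', 'y0'), ('x1', 'y1')]   # (data, labels) per partition

def _unzip(list_of_parts):
    return list(zip(*list_of_parts))

data, labels = _unzip(parts)           # one tuple per field, spanning partitions
assert data == ('x0', 'x1')
assert labels == ('y0', 'y1')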
@@ -585,9 +588,9 @@ def _create_dmatrix(feature_names, feature_types, meta_names, missing,
         return d

     def concat_or_none(data):
-        if data is not None:
-            return concat(data)
-        return data
+        if all([part is None for part in data]):
+            return None
+        return concat(data)

     (data, labels, weights, base_margin,
      label_lower_bound, label_upper_bound) = _get_worker_parts(
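Why the guard changed: after `_unzip`, a meta field that was never supplied arrives as a tuple of Nones rather than None itself, so the old `if data is not None` check would hand `(None, None, ..)` to `concat`. A toy version, with a list-based stand-in for the real concat helper:

def concat(data):                      # stand-in for xgboost.dask's concat
    out = []
    for part in data:
        out.extend(part)
    return out

def concat_or_none(data):
    if all([part is None for part in data]):
        return None
    return concat(data)

assert concat_or_none((None, None)) is None        # meta absent everywhere
assert concat_or_none(([1, 2], [3])) == [1, 2, 3]  # meta present in all parts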
@@ -795,7 +798,7 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
     feature_names = data.feature_names
     feature_types = data.feature_types
     missing = data.missing
-    has_margin = "base_margin" in data.meta_names
+    meta_names = data.meta_names

     def dispatched_predict(worker_id):
         '''Perform prediction on each worker.'''
@@ -803,10 +806,11 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):

         worker = distributed.get_worker()
         list_of_parts = _get_worker_parts_ordered(
-            has_margin, worker_map, partition_order, worker)
+            meta_names, worker_map, partition_order, worker)
         predictions = []
         booster.set_param({'nthread': worker.nthreads})
-        for data, base_margin, order in list_of_parts:
+        for parts in list_of_parts:
+            (data, _, _, base_margin, _, _, order) = parts
             local_part = DMatrix(
                 data,
                 base_margin=base_margin,
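The `order` element threaded through each tuple is what keeps prediction deterministic: workers gather partitions in arbitrary order, and the caller can sort on it afterwards. A hypothetical illustration (toy values, not the literal reassembly code in `_predict_async`):

results = [('pred-of-part1', 1), ('pred-of-part0', 0)]  # out of order
results.sort(key=lambda pair: pair[1])                  # restore partition order
assert [p for p, _ in results] == ['pred-of-part0', 'pred-of-part1']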
@@ -829,12 +833,15 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
         LOGGER.info('Get shape on %d', worker_id)
         worker = distributed.get_worker()
         list_of_parts = _get_worker_parts_ordered(
-            False,
+            meta_names,
             worker_map,
             partition_order,
             worker
         )
-        shapes = [(part.shape, order) for part, _, order in list_of_parts]
+        shapes = []
+        for parts in list_of_parts:
+            (data, _, _, _, _, _, order) = parts
+            shapes.append((data.shape, order))
         return shapes

     async def map_function(func):
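Only the data shape and the partition index survive `dispatched_get_shape`; downstream, the (shape, order) pairs can be sorted back into partition order to rebuild the output chunking. A hypothetical illustration with toy shapes:

shapes = [((30, 4), 1), ((20, 4), 0)]          # gathered out of order
shapes.sort(key=lambda pair: pair[1])          # restore partition order
rows_per_chunk = tuple(shape[0] for shape, _ in shapes)
assert rows_per_chunk == (20, 30)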
@@ -974,8 +981,7 @@ def inplace_predict(client, model, data,
                                  missing=missing)


-async def _evaluation_matrices(client, validation_set,
-                               sample_weight, missing):
+async def _evaluation_matrices(client, validation_set, sample_weight, missing):
     '''
     Parameters
     ----------
@@ -998,8 +1004,7 @@ async def _evaluation_matrices(client, validation_set, sample_weight, missing):
     if validation_set is not None:
         assert isinstance(validation_set, list)
         for i, e in enumerate(validation_set):
-            w = (sample_weight[i]
-                 if sample_weight is not None else None)
+            w = (sample_weight[i] if sample_weight is not None else None)
             dmat = await DaskDMatrix(client=client, data=e[0], label=e[1],
                                      weight=w, missing=missing)
             evals.append((dmat, 'validation_{}'.format(i)))
tests/python/test_with_dask.py

@@ -566,6 +566,28 @@ def test_predict():
     assert shap.shape[1] == kCols + 1


+def test_predict_with_meta(client):
+    X, y, w = generate_array(with_weights=True)
+    partition_size = 20
+    margin = da.random.random(kRows, partition_size) + 1e4
+
+    dtrain = DaskDMatrix(client, X, y, weight=w, base_margin=margin)
+    booster = xgb.dask.train(
+        client, {}, dtrain, num_boost_round=4)['booster']
+
+    prediction = xgb.dask.predict(client, model=booster, data=dtrain)
+    assert prediction.ndim == 1
+    assert prediction.shape[0] == kRows
+
+    prediction = client.compute(prediction).result()
+    assert np.all(prediction > 1e3)
+
+    m = xgb.DMatrix(X.compute())
+    m.set_info(label=y.compute(), weight=w.compute(), base_margin=margin.compute())
+    single = booster.predict(m)  # Make sure the ordering is correct.
+    assert np.all(prediction == single)
+
+
 def run_aft_survival(client, dmatrix_t):
     # survival doesn't handle empty dataset well.
     df = dd.read_csv(os.path.join(tm.PROJECT_ROOT, 'demo', 'data',
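To exercise the new test outside pytest, a minimal harness might look like this (assumes a local `distributed` cluster; the worker count is arbitrary):

from distributed import Client, LocalCluster

if __name__ == '__main__':
    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        test_predict_with_meta(client)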