[dask] Accept other inputs for prediction. (#5428)
* Return a dask Series when the input is a dask DataFrame.
* Merge _assert_client into _xgb_get_client.
parent 8ca06ab329
commit 760d5d0c3c
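The gist of the change, as a minimal usage sketch (the cluster setup, random data, and parameters below are illustrative assumptions, not taken from the diff): xgboost.dask.predict now accepts a DaskDMatrix, a dask.array.Array, or a dask.dataframe.DataFrame, and returns a dask Series when given a DataFrame.

import dask.array as da
import dask.dataframe as dd
from distributed import Client, LocalCluster
import xgboost as xgb
from xgboost.dask import DaskDMatrix  # loading the submodule also makes xgb.dask usable

with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
    # Synthetic data, purely for illustration.
    X = da.random.random((1000, 10), chunks=(100, 10))
    y = da.random.random(1000, chunks=100)

    dtrain = DaskDMatrix(client, X, y)
    output = xgb.dask.train(client, {'tree_method': 'hist'}, dtrain,
                            num_boost_round=2)
    booster = output['booster']

    # DaskDMatrix input -> dask.array.Array prediction (existing behaviour).
    from_dmatrix = xgb.dask.predict(client, booster, dtrain)

    # dask.array.Array input -> dask.array.Array prediction (new).
    from_array = xgb.dask.predict(client, booster, X)

    # dask.dataframe.DataFrame input -> dask.dataframe.Series prediction (new).
    df = dd.from_dask_array(X, columns=['f%d' % i for i in range(10)])
    from_df = xgb.dask.predict(client, model=booster, data=df)
    assert isinstance(from_df, dd.Series)

The missing argument added to predict only applies to the raw dask inputs, where a DMatrix is built per partition; a DaskDMatrix keeps the missing value it was constructed with.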
@@ -103,6 +103,9 @@ def concat(value):
 
 def _xgb_get_client(client):
     '''Simple wrapper around testing None.'''
+    if not isinstance(client, (type(get_client()), type(None))):
+        raise TypeError(
+            _expect([type(get_client()), type(None)], type(client)))
     ret = get_client() if client is None else client
     return ret
 
@@ -112,12 +115,6 @@ def _get_client_workers(client):
     return workers
 
 
-def _assert_client(client):
-    if not isinstance(client, (type(get_client()), type(None))):
-        raise TypeError(
-            _expect([type(get_client()), type(None)], type(client)))
-
-
 class DaskDMatrix:
     # pylint: disable=missing-docstring, too-many-instance-attributes
     '''DMatrix holding on references to Dask DataFrame or Dask Array. Constructing
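With the two hunks above, the type check that _assert_client used to perform now lives inside _xgb_get_client, which both validates and resolves the client. A small sketch of the intended behaviour (it pokes at a private helper and assumes a running LocalCluster, so it is illustrative only):

from distributed import Client, LocalCluster
from xgboost import dask as xgb_dask

with LocalCluster(n_workers=1) as cluster, Client(cluster) as client:
    # An explicit client is passed through unchanged.
    assert xgb_dask._xgb_get_client(client) is client
    # None resolves to the active client via distributed.get_client().
    assert isinstance(xgb_dask._xgb_get_client(None), Client)
    # Anything else is rejected, which used to be _assert_client's job.
    try:
        xgb_dask._xgb_get_client('not a client')
    except TypeError:
        pass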
@@ -155,7 +152,7 @@ class DaskDMatrix:
                  feature_names=None,
                  feature_types=None):
         _assert_dask_support()
-        _assert_client(client)
+        client = _xgb_get_client(client)
 
         self.feature_names = feature_names
         self.feature_types = feature_types
@@ -177,7 +174,6 @@ class DaskDMatrix:
         self.has_label = label is not None
         self.has_weights = weight is not None
 
-        client = _xgb_get_client(client)
         client.sync(self.map_local_data, client, data, label, weight)
 
     async def map_local_data(self, client, data, label=None, weights=None):
@@ -391,13 +387,12 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
 
     '''
     _assert_dask_support()
-    _assert_client(client)
+    client = _xgb_get_client(client)
     if 'evals_result' in kwargs.keys():
         raise ValueError(
             'evals_result is not supported in dask interface.',
             'The evaluation history is returned as result of training.')
 
-    client = _xgb_get_client(client)
     workers = list(_get_client_workers(client).keys())
 
     rabit_args = _get_rabit_args(workers, client)
@@ -452,7 +447,7 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
     return list(filter(lambda ret: ret is not None, results))[0]
 
 
-def predict(client, model, data, *args):
+def predict(client, model, data, *args, missing=numpy.nan):
     '''Run prediction with a trained booster.
 
     .. note::
@@ -466,32 +461,55 @@ def predict(client, model, data, *args):
         returned from dask if it's set to None.
     model: A Booster or a dictionary returned by `xgboost.dask.train`.
         The trained model.
-    data: DaskDMatrix
+    data: DaskDMatrix/dask.dataframe.DataFrame/dask.array.Array
         Input data used for prediction.
+    missing: float
+        Used when input data is not DaskDMatrix.  Specify the value
+        considered as missing.
 
     Returns
     -------
-    prediction: dask.array.Array
+    prediction: dask.array.Array/dask.dataframe.Series
 
     '''
     _assert_dask_support()
-    _assert_client(client)
+    client = _xgb_get_client(client)
     if isinstance(model, Booster):
         booster = model
     elif isinstance(model, dict):
         booster = model['booster']
     else:
         raise TypeError(_expect([Booster, dict], type(model)))
-    if not isinstance(data, DaskDMatrix):
-        raise TypeError(_expect([DaskDMatrix], type(data)))
+    if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)):
+        raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame],
+                                type(data)))
 
+    def mapped_predict(partition, is_df):
+        worker = distributed_get_worker()
+        m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
+        predt = booster.predict(m, *args, validate_features=False)
+        if is_df:
+            predt = DataFrame(predt, columns=['prediction'])
+        return predt
+
+    if isinstance(data, da.Array):
+        predictions = client.submit(
+            da.map_blocks,
+            mapped_predict, data, False, drop_axis=1,
+            dtype=numpy.float32
+        ).result()
+        return predictions
+    if isinstance(data, dd.DataFrame):
+        import dask
+        predictions = client.submit(
+            dd.map_partitions,
+            mapped_predict, data, True,
+            meta=dask.dataframe.utils.make_meta({'prediction': 'f4'})
+        ).result()
+        return predictions.iloc[:, 0]
+
+    # Prediction on dask DMatrix.
     worker_map = data.worker_map
-    client = _xgb_get_client(client)
 
+    missing = data.missing
+    feature_names = data.feature_names
+    feature_types = data.feature_types
+
     def dispatched_predict(worker_id):
         '''Perform prediction on each worker.'''
@@ -502,9 +520,9 @@ def predict(client, model, data, *args):
         booster.set_param({'nthread': worker.nthreads})
         for part, order in list_of_parts:
             local_x = DMatrix(part,
-                              feature_names=data.feature_names,
-                              feature_types=data.feature_types,
-                              missing=data.missing,
+                              feature_names=feature_names,
+                              feature_types=feature_types,
+                              missing=missing,
                               nthread=worker.nthreads)
             predt = booster.predict(data=local_x,
                                     validate_features=local_x.num_row() != 0,
@@ -520,8 +538,7 @@ def predict(client, model, data, *args):
         list_of_parts = data.get_worker_x_ordered(worker)
         shapes = []
         for part, order in list_of_parts:
-            s = part.shape
-            shapes.append((s, order))
+            shapes.append((part.shape, order))
         return shapes
 
     def map_function(func):
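For context on the dispatch added to predict above: da.map_blocks with drop_axis=1 collapses each 2-D chunk into a 1-D chunk of per-row predictions, while dd.map_partitions with an explicit meta yields a single-column frame whose first column is returned as a Series. A self-contained toy of the same pattern, with fake_predict_block and fake_predict_partition standing in for Booster.predict (no xgboost involved):

import numpy as np
import pandas as pd
import dask.array as da
import dask.dataframe as dd
from dask.dataframe.utils import make_meta


def fake_predict_block(block):
    # Stand-in for Booster.predict on one chunk: one float32 per row.
    return block.sum(axis=1).astype(np.float32)


def fake_predict_partition(partition):
    # Stand-in for the is_df branch: wrap per-row outputs in a one-column frame.
    return pd.DataFrame({'prediction': fake_predict_block(partition.to_numpy())})


X = da.random.random((20, 4), chunks=(5, 4))

# Array path: drop_axis=1 turns (rows, cols) chunks into (rows,) chunks.
arr_predt = da.map_blocks(fake_predict_block, X, drop_axis=1, dtype=np.float32)
assert arr_predt.shape == (20,)

# DataFrame path: meta declares the single 'prediction' column up front.
df = dd.from_dask_array(X, columns=['f0', 'f1', 'f2', 'f3'])
df_predt = dd.map_partitions(fake_predict_partition, df,
                             meta=make_meta({'prediction': 'f4'}))
assert isinstance(df_predt.iloc[:, 0], dd.Series)

The remaining hunks extend the dask tests to cover the new input types.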
@@ -57,7 +57,13 @@ def test_from_dask_dataframe():
                 xgb.dask.train(
                     client, {}, dtrain, num_boost_round=2, evals_result={})
             # force prediction to be computed
-            prediction = prediction.compute()
+            from_dmatrix = prediction.compute()
+
+            prediction = xgb.dask.predict(client, model=booster, data=X)
+            from_df = prediction.compute()
+
+            assert isinstance(prediction, dd.Series)
+            assert np.all(from_dmatrix == from_df.to_numpy())
 
 
 def test_from_dask_array():
@@ -84,6 +90,12 @@ def test_from_dask_array():
             config = json.loads(booster.save_config())
             assert int(config['learner']['generic_param']['nthread']) == 5
 
+            from_arr = xgb.dask.predict(
+                client, model=booster, data=X)
+
+            assert isinstance(from_arr, da.Array)
+            assert np.all(single_node_predt == from_arr.compute())
+
 
 def test_dask_regressor():
     with LocalCluster(n_workers=5) as cluster: