[dask] Accept other inputs for prediction. (#5428)
* Return a series when the input is a dataframe.
* Merge `_assert_client` into `_xgb_get_client`.
This commit is contained in:
parent
8ca06ab329
commit
760d5d0c3c
@ -103,6 +103,9 @@ def concat(value):
|
|||||||
|
|
||||||
def _xgb_get_client(client):
|
def _xgb_get_client(client):
|
||||||
'''Simple wrapper around testing None.'''
|
'''Simple wrapper around testing None.'''
|
||||||
|
if not isinstance(client, (type(get_client()), type(None))):
|
||||||
|
raise TypeError(
|
||||||
|
_expect([type(get_client()), type(None)], type(client)))
|
||||||
ret = get_client() if client is None else client
|
ret = get_client() if client is None else client
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@ -112,12 +115,6 @@ def _get_client_workers(client):
|
|||||||
return workers
|
return workers
|
||||||
|
|
||||||
|
|
||||||
def _assert_client(client):
|
|
||||||
if not isinstance(client, (type(get_client()), type(None))):
|
|
||||||
raise TypeError(
|
|
||||||
_expect([type(get_client()), type(None)], type(client)))
|
|
||||||
|
|
||||||
|
|
||||||
class DaskDMatrix:
|
class DaskDMatrix:
|
||||||
# pylint: disable=missing-docstring, too-many-instance-attributes
|
# pylint: disable=missing-docstring, too-many-instance-attributes
|
||||||
'''DMatrix holding on references to Dask DataFrame or Dask Array. Constructing
|
'''DMatrix holding on references to Dask DataFrame or Dask Array. Constructing
|
||||||
@ -155,7 +152,7 @@ class DaskDMatrix:
|
|||||||
feature_names=None,
|
feature_names=None,
|
||||||
feature_types=None):
|
feature_types=None):
|
||||||
_assert_dask_support()
|
_assert_dask_support()
|
||||||
_assert_client(client)
|
client = _xgb_get_client(client)
|
||||||
|
|
||||||
self.feature_names = feature_names
|
self.feature_names = feature_names
|
||||||
self.feature_types = feature_types
|
self.feature_types = feature_types
|
||||||
@ -177,7 +174,6 @@ class DaskDMatrix:
|
|||||||
self.has_label = label is not None
|
self.has_label = label is not None
|
||||||
self.has_weights = weight is not None
|
self.has_weights = weight is not None
|
||||||
|
|
||||||
client = _xgb_get_client(client)
|
|
||||||
client.sync(self.map_local_data, client, data, label, weight)
|
client.sync(self.map_local_data, client, data, label, weight)
|
||||||
|
|
||||||
async def map_local_data(self, client, data, label=None, weights=None):
|
async def map_local_data(self, client, data, label=None, weights=None):
|
||||||
@ -391,13 +387,12 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
|
|||||||
|
|
||||||
'''
|
'''
|
||||||
_assert_dask_support()
|
_assert_dask_support()
|
||||||
_assert_client(client)
|
client = _xgb_get_client(client)
|
||||||
if 'evals_result' in kwargs.keys():
|
if 'evals_result' in kwargs.keys():
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'evals_result is not supported in dask interface.',
|
'evals_result is not supported in dask interface.',
|
||||||
'The evaluation history is returned as result of training.')
|
'The evaluation history is returned as result of training.')
|
||||||
|
|
||||||
client = _xgb_get_client(client)
|
|
||||||
workers = list(_get_client_workers(client).keys())
|
workers = list(_get_client_workers(client).keys())
|
||||||
|
|
||||||
rabit_args = _get_rabit_args(workers, client)
|
rabit_args = _get_rabit_args(workers, client)
|
||||||
@ -452,7 +447,7 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
|
|||||||
return list(filter(lambda ret: ret is not None, results))[0]
|
return list(filter(lambda ret: ret is not None, results))[0]
|
||||||
|
|
||||||
|
|
||||||
def predict(client, model, data, *args):
|
def predict(client, model, data, *args, missing=numpy.nan):
|
||||||
'''Run prediction with a trained booster.
|
'''Run prediction with a trained booster.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
@ -466,32 +461,55 @@ def predict(client, model, data, *args):
|
|||||||
returned from dask if it's set to None.
|
returned from dask if it's set to None.
|
||||||
model: A Booster or a dictionary returned by `xgboost.dask.train`.
|
model: A Booster or a dictionary returned by `xgboost.dask.train`.
|
||||||
The trained model.
|
The trained model.
|
||||||
data: DaskDMatrix
|
data: DaskDMatrix/dask.dataframe.DataFrame/dask.array.Array
|
||||||
Input data used for prediction.
|
Input data used for prediction.
|
||||||
|
missing: float
|
||||||
|
Used when input data is not DaskDMatrix. Specify the value
|
||||||
|
considered as missing.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
prediction: dask.array.Array
|
prediction: dask.array.Array/dask.dataframe.Series
|
||||||
|
|
||||||
'''
|
'''
|
||||||
_assert_dask_support()
|
_assert_dask_support()
|
||||||
_assert_client(client)
|
client = _xgb_get_client(client)
|
||||||
if isinstance(model, Booster):
|
if isinstance(model, Booster):
|
||||||
booster = model
|
booster = model
|
||||||
elif isinstance(model, dict):
|
elif isinstance(model, dict):
|
||||||
booster = model['booster']
|
booster = model['booster']
|
||||||
else:
|
else:
|
||||||
raise TypeError(_expect([Booster, dict], type(model)))
|
raise TypeError(_expect([Booster, dict], type(model)))
|
||||||
|
if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)):
|
||||||
|
raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame],
|
||||||
|
type(data)))
|
||||||
|
|
||||||
if not isinstance(data, DaskDMatrix):
|
def mapped_predict(partition, is_df):
|
||||||
raise TypeError(_expect([DaskDMatrix], type(data)))
|
worker = distributed_get_worker()
|
||||||
|
m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
|
||||||
|
predt = booster.predict(m, *args, validate_features=False)
|
||||||
|
if is_df:
|
||||||
|
predt = DataFrame(predt, columns=['prediction'])
|
||||||
|
return predt
|
||||||
|
|
||||||
|
if isinstance(data, da.Array):
|
||||||
|
predictions = client.submit(
|
||||||
|
da.map_blocks,
|
||||||
|
mapped_predict, data, False, drop_axis=1,
|
||||||
|
dtype=numpy.float32
|
||||||
|
).result()
|
||||||
|
return predictions
|
||||||
|
if isinstance(data, dd.DataFrame):
|
||||||
|
import dask
|
||||||
|
predictions = client.submit(
|
||||||
|
dd.map_partitions,
|
||||||
|
mapped_predict, data, True,
|
||||||
|
meta=dask.dataframe.utils.make_meta({'prediction': 'f4'})
|
||||||
|
).result()
|
||||||
|
return predictions.iloc[:, 0]
|
||||||
|
|
||||||
|
# Prediction on dask DMatrix.
|
||||||
worker_map = data.worker_map
|
worker_map = data.worker_map
|
||||||
client = _xgb_get_client(client)
|
|
||||||
|
|
||||||
missing = data.missing
|
|
||||||
feature_names = data.feature_names
|
|
||||||
feature_types = data.feature_types
|
|
||||||
|
|
||||||
def dispatched_predict(worker_id):
|
def dispatched_predict(worker_id):
|
||||||
'''Perform prediction on each worker.'''
|
'''Perform prediction on each worker.'''
|
||||||
@ -502,9 +520,9 @@ def predict(client, model, data, *args):
|
|||||||
booster.set_param({'nthread': worker.nthreads})
|
booster.set_param({'nthread': worker.nthreads})
|
||||||
for part, order in list_of_parts:
|
for part, order in list_of_parts:
|
||||||
local_x = DMatrix(part,
|
local_x = DMatrix(part,
|
||||||
feature_names=feature_names,
|
feature_names=data.feature_names,
|
||||||
feature_types=feature_types,
|
feature_types=data.feature_types,
|
||||||
missing=missing,
|
missing=data.missing,
|
||||||
nthread=worker.nthreads)
|
nthread=worker.nthreads)
|
||||||
predt = booster.predict(data=local_x,
|
predt = booster.predict(data=local_x,
|
||||||
validate_features=local_x.num_row() != 0,
|
validate_features=local_x.num_row() != 0,
|
||||||
@ -520,8 +538,7 @@ def predict(client, model, data, *args):
|
|||||||
list_of_parts = data.get_worker_x_ordered(worker)
|
list_of_parts = data.get_worker_x_ordered(worker)
|
||||||
shapes = []
|
shapes = []
|
||||||
for part, order in list_of_parts:
|
for part, order in list_of_parts:
|
||||||
s = part.shape
|
shapes.append((part.shape, order))
|
||||||
shapes.append((s, order))
|
|
||||||
return shapes
|
return shapes
|
||||||
|
|
||||||
def map_function(func):
|
def map_function(func):
|
||||||
|
|||||||
@ -57,7 +57,13 @@ def test_from_dask_dataframe():
|
|||||||
xgb.dask.train(
|
xgb.dask.train(
|
||||||
client, {}, dtrain, num_boost_round=2, evals_result={})
|
client, {}, dtrain, num_boost_round=2, evals_result={})
|
||||||
# force prediction to be computed
|
# force prediction to be computed
|
||||||
prediction = prediction.compute()
|
from_dmatrix = prediction.compute()
|
||||||
|
|
||||||
|
prediction = xgb.dask.predict(client, model=booster, data=X)
|
||||||
|
from_df = prediction.compute()
|
||||||
|
|
||||||
|
assert isinstance(prediction, dd.Series)
|
||||||
|
assert np.all(from_dmatrix == from_df.to_numpy())
|
||||||
|
|
||||||
|
|
||||||
def test_from_dask_array():
|
def test_from_dask_array():
|
||||||
@ -84,6 +90,12 @@ def test_from_dask_array():
|
|||||||
config = json.loads(booster.save_config())
|
config = json.loads(booster.save_config())
|
||||||
assert int(config['learner']['generic_param']['nthread']) == 5
|
assert int(config['learner']['generic_param']['nthread']) == 5
|
||||||
|
|
||||||
|
from_arr = xgb.dask.predict(
|
||||||
|
client, model=booster, data=X)
|
||||||
|
|
||||||
|
assert isinstance(from_arr, da.Array)
|
||||||
|
assert np.all(single_node_predt == from_arr.compute())
|
||||||
|
|
||||||
|
|
||||||
def test_dask_regressor():
|
def test_dask_regressor():
|
||||||
with LocalCluster(n_workers=5) as cluster:
|
with LocalCluster(n_workers=5) as cluster:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user