Allow kwargs in dask predict (#6117)
This commit is contained in:
parent
b5f52f0b1b
commit
47350f6acb
@ -688,8 +688,8 @@ async def _direct_predict_impl(client, data, predict_fn):
|
|||||||
|
|
||||||
|
|
||||||
# pylint: disable=too-many-statements
|
# pylint: disable=too-many-statements
|
||||||
async def _predict_async(client: Client, model, data, *args,
|
async def _predict_async(client: Client, model, data, missing=numpy.nan, **kwargs):
|
||||||
missing=numpy.nan):
|
|
||||||
if isinstance(model, Booster):
|
if isinstance(model, Booster):
|
||||||
booster = model
|
booster = model
|
||||||
elif isinstance(model, dict):
|
elif isinstance(model, dict):
|
||||||
@ -704,7 +704,7 @@ async def _predict_async(client: Client, model, data, *args,
|
|||||||
worker = distributed_get_worker()
|
worker = distributed_get_worker()
|
||||||
booster.set_param({'nthread': worker.nthreads})
|
booster.set_param({'nthread': worker.nthreads})
|
||||||
m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
|
m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
|
||||||
predt = booster.predict(m, *args, validate_features=False)
|
predt = booster.predict(m, validate_features=False, **kwargs)
|
||||||
if is_df:
|
if is_df:
|
||||||
if lazy_isinstance(partition, 'cudf', 'core.dataframe.DataFrame'):
|
if lazy_isinstance(partition, 'cudf', 'core.dataframe.DataFrame'):
|
||||||
import cudf # pylint: disable=import-error
|
import cudf # pylint: disable=import-error
|
||||||
@ -737,7 +737,7 @@ async def _predict_async(client: Client, model, data, *args,
|
|||||||
missing=missing, nthread=worker.nthreads)
|
missing=missing, nthread=worker.nthreads)
|
||||||
predt = booster.predict(data=local_x,
|
predt = booster.predict(data=local_x,
|
||||||
validate_features=local_x.num_row() != 0,
|
validate_features=local_x.num_row() != 0,
|
||||||
*args)
|
**kwargs)
|
||||||
columns = 1 if len(predt.shape) == 1 else predt.shape[1]
|
columns = 1 if len(predt.shape) == 1 else predt.shape[1]
|
||||||
ret = ((delayed(predt), columns), order)
|
ret = ((delayed(predt), columns), order)
|
||||||
predictions.append(ret)
|
predictions.append(ret)
|
||||||
@ -784,7 +784,7 @@ async def _predict_async(client: Client, model, data, *args,
|
|||||||
return predictions
|
return predictions
|
||||||
|
|
||||||
|
|
||||||
def predict(client, model, data, *args, missing=numpy.nan):
|
def predict(client, model, data, missing=numpy.nan, **kwargs):
|
||||||
'''Run prediction with a trained booster.
|
'''Run prediction with a trained booster.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
@ -813,8 +813,8 @@ def predict(client, model, data, *args, missing=numpy.nan):
|
|||||||
'''
|
'''
|
||||||
_assert_dask_support()
|
_assert_dask_support()
|
||||||
client = _xgb_get_client(client)
|
client = _xgb_get_client(client)
|
||||||
return client.sync(_predict_async, client, model, data, *args,
|
return client.sync(_predict_async, client, model, data,
|
||||||
missing=missing)
|
missing=missing, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
async def _inplace_predict_async(client, model, data,
|
async def _inplace_predict_async(client, model, data,
|
||||||
|
|||||||
@ -276,7 +276,6 @@ def test_sklearn_grid_search():
|
|||||||
|
|
||||||
|
|
||||||
def run_empty_dmatrix_reg(client, parameters):
|
def run_empty_dmatrix_reg(client, parameters):
|
||||||
|
|
||||||
def _check_outputs(out, predictions):
|
def _check_outputs(out, predictions):
|
||||||
assert isinstance(out['booster'], xgb.dask.Booster)
|
assert isinstance(out['booster'], xgb.dask.Booster)
|
||||||
assert len(out['history']['validation']['rmse']) == 2
|
assert len(out['history']['validation']['rmse']) == 2
|
||||||
@ -447,7 +446,6 @@ async def run_dask_classifier_asyncio(scheduler_address):
|
|||||||
assert probas.shape[0] == kRows
|
assert probas.shape[0] == kRows
|
||||||
assert probas.shape[1] == 10
|
assert probas.shape[1] == 10
|
||||||
|
|
||||||
|
|
||||||
# Test with dataframe.
|
# Test with dataframe.
|
||||||
X_d = dd.from_dask_array(X)
|
X_d = dd.from_dask_array(X)
|
||||||
y_d = dd.from_dask_array(y)
|
y_d = dd.from_dask_array(y)
|
||||||
@ -472,6 +470,28 @@ def test_with_asyncio():
|
|||||||
asyncio.run(run_dask_classifier_asyncio(address))
|
asyncio.run(run_dask_classifier_asyncio(address))
|
||||||
|
|
||||||
|
|
||||||
|
def test_predict():
|
||||||
|
with LocalCluster(n_workers=kWorkers) as cluster:
|
||||||
|
with Client(cluster) as client:
|
||||||
|
X, y = generate_array()
|
||||||
|
dtrain = DaskDMatrix(client, X, y)
|
||||||
|
booster = xgb.dask.train(
|
||||||
|
client, {}, dtrain, num_boost_round=2)['booster']
|
||||||
|
|
||||||
|
pred = xgb.dask.predict(client, model=booster, data=dtrain)
|
||||||
|
assert pred.ndim == 1
|
||||||
|
assert pred.shape[0] == kRows
|
||||||
|
|
||||||
|
margin = xgb.dask.predict(client, model=booster, data=dtrain, output_margin=True)
|
||||||
|
assert margin.ndim == 1
|
||||||
|
assert margin.shape[0] == kRows
|
||||||
|
|
||||||
|
shap = xgb.dask.predict(client, model=booster, data=dtrain, pred_contribs=True)
|
||||||
|
assert shap.ndim == 2
|
||||||
|
assert shap.shape[0] == kRows
|
||||||
|
assert shap.shape[1] == kCols + 1
|
||||||
|
|
||||||
|
|
||||||
class TestWithDask:
|
class TestWithDask:
|
||||||
def run_updater_test(self, client, params, num_rounds, dataset,
|
def run_updater_test(self, client, params, num_rounds, dataset,
|
||||||
tree_method):
|
tree_method):
|
||||||
@ -489,9 +509,9 @@ class TestWithDask:
|
|||||||
chunk = 128
|
chunk = 128
|
||||||
X = da.from_array(dataset.X,
|
X = da.from_array(dataset.X,
|
||||||
chunks=(chunk, dataset.X.shape[1]))
|
chunks=(chunk, dataset.X.shape[1]))
|
||||||
y = da.from_array(dataset.y, chunks=(chunk, ))
|
y = da.from_array(dataset.y, chunks=(chunk,))
|
||||||
if dataset.w is not None:
|
if dataset.w is not None:
|
||||||
w = da.from_array(dataset.w, chunks=(chunk, ))
|
w = da.from_array(dataset.w, chunks=(chunk,))
|
||||||
else:
|
else:
|
||||||
w = None
|
w = None
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user