Allow kwargs in dask predict (#6117)

author Rory Mitchell, 2020-09-15 13:04:03 +12:00 (committed by GitHub)
parent b5f52f0b1b
commit 47350f6acb
2 changed files with 33 additions and 13 deletions
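
This change replaces the positional *args pass-through in xgboost.dask.predict with **kwargs, so prediction options can be supplied by name and are forwarded to Booster.predict on each worker. A minimal usage sketch follows; the local cluster setup and random data below are illustrative placeholders, not part of this commit:

    # Illustrative sketch: keyword arguments given to xgboost.dask.predict are
    # forwarded to Booster.predict on each worker. The cluster and data here
    # are placeholders, not taken from this commit.
    import dask.array as da
    import xgboost as xgb
    from distributed import Client, LocalCluster

    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        X = da.random.random((1000, 10), chunks=(250, 10))
        y = da.random.random((1000,), chunks=(250,))
        dtrain = xgb.dask.DaskDMatrix(client, X, y)
        booster = xgb.dask.train(client, {}, dtrain, num_boost_round=2)['booster']

        # Previously only positional *args reached Booster.predict; now keyword
        # options such as output_margin and pred_contribs are accepted.
        margin = xgb.dask.predict(client, booster, dtrain, output_margin=True)
        contribs = xgb.dask.predict(client, booster, dtrain, pred_contribs=True)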

Changed file 1 of 2

@@ -688,8 +688,8 @@ async def _direct_predict_impl(client, data, predict_fn):
 # pylint: disable=too-many-statements
-async def _predict_async(client: Client, model, data, *args,
-                         missing=numpy.nan):
+async def _predict_async(client: Client, model, data, missing=numpy.nan, **kwargs):
     if isinstance(model, Booster):
         booster = model
     elif isinstance(model, dict):
@@ -704,7 +704,7 @@ async def _predict_async(client: Client, model, data, *args,
         worker = distributed_get_worker()
         booster.set_param({'nthread': worker.nthreads})
         m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
-        predt = booster.predict(m, *args, validate_features=False)
+        predt = booster.predict(m, validate_features=False, **kwargs)
         if is_df:
             if lazy_isinstance(partition, 'cudf', 'core.dataframe.DataFrame'):
                 import cudf  # pylint: disable=import-error
@@ -737,7 +737,7 @@ async def _predict_async(client: Client, model, data, *args,
                               missing=missing, nthread=worker.nthreads)
             predt = booster.predict(data=local_x,
                                     validate_features=local_x.num_row() != 0,
-                                    *args)
+                                    **kwargs)
             columns = 1 if len(predt.shape) == 1 else predt.shape[1]
             ret = ((delayed(predt), columns), order)
             predictions.append(ret)
@@ -784,7 +784,7 @@ async def _predict_async(client: Client, model, data, *args,
     return predictions


-def predict(client, model, data, *args, missing=numpy.nan):
+def predict(client, model, data, missing=numpy.nan, **kwargs):
     '''Run prediction with a trained booster.

     .. note::
@@ -813,8 +813,8 @@ def predict(client, model, data, *args, missing=numpy.nan):
     '''
     _assert_dask_support()
     client = _xgb_get_client(client)
-    return client.sync(_predict_async, client, model, data, *args,
-                       missing=missing)
+    return client.sync(_predict_async, client, model, data,
+                       missing=missing, **kwargs)


 async def _inplace_predict_async(client, model, data,

Changed file 2 of 2

@@ -276,7 +276,6 @@ def test_sklearn_grid_search():
 def run_empty_dmatrix_reg(client, parameters):
     def _check_outputs(out, predictions):
         assert isinstance(out['booster'], xgb.dask.Booster)
-
         assert len(out['history']['validation']['rmse']) == 2
@@ -447,7 +446,6 @@ async def run_dask_classifier_asyncio(scheduler_address):
     assert probas.shape[0] == kRows
     assert probas.shape[1] == 10
-
     # Test with dataframe.
     X_d = dd.from_dask_array(X)
     y_d = dd.from_dask_array(y)
@@ -472,6 +470,28 @@ def test_with_asyncio():
     asyncio.run(run_dask_classifier_asyncio(address))


+def test_predict():
+    with LocalCluster(n_workers=kWorkers) as cluster:
+        with Client(cluster) as client:
+            X, y = generate_array()
+            dtrain = DaskDMatrix(client, X, y)
+            booster = xgb.dask.train(
+                client, {}, dtrain, num_boost_round=2)['booster']
+
+            pred = xgb.dask.predict(client, model=booster, data=dtrain)
+            assert pred.ndim == 1
+            assert pred.shape[0] == kRows
+
+            margin = xgb.dask.predict(client, model=booster, data=dtrain, output_margin=True)
+            assert margin.ndim == 1
+            assert margin.shape[0] == kRows
+
+            shap = xgb.dask.predict(client, model=booster, data=dtrain, pred_contribs=True)
+            assert shap.ndim == 2
+            assert shap.shape[0] == kRows
+            assert shap.shape[1] == kCols + 1
+
+
 class TestWithDask:
     def run_updater_test(self, client, params, num_rounds, dataset,
                          tree_method):
@@ -489,9 +509,9 @@ class TestWithDask:
         chunk = 128
         X = da.from_array(dataset.X,
                           chunks=(chunk, dataset.X.shape[1]))
-        y = da.from_array(dataset.y, chunks=(chunk, ))
+        y = da.from_array(dataset.y, chunks=(chunk,))
         if dataset.w is not None:
-            w = da.from_array(dataset.w, chunks=(chunk, ))
+            w = da.from_array(dataset.w, chunks=(chunk,))
         else:
             w = None