Allow kwargs in dask predict (#6117)
This commit is contained in:
@@ -688,8 +688,8 @@ async def _direct_predict_impl(client, data, predict_fn):
|
||||
|
||||
|
||||
# pylint: disable=too-many-statements
|
||||
async def _predict_async(client: Client, model, data, *args,
|
||||
missing=numpy.nan):
|
||||
async def _predict_async(client: Client, model, data, missing=numpy.nan, **kwargs):
|
||||
|
||||
if isinstance(model, Booster):
|
||||
booster = model
|
||||
elif isinstance(model, dict):
|
||||
@@ -704,7 +704,7 @@ async def _predict_async(client: Client, model, data, *args,
|
||||
worker = distributed_get_worker()
|
||||
booster.set_param({'nthread': worker.nthreads})
|
||||
m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
|
||||
predt = booster.predict(m, *args, validate_features=False)
|
||||
predt = booster.predict(m, validate_features=False, **kwargs)
|
||||
if is_df:
|
||||
if lazy_isinstance(partition, 'cudf', 'core.dataframe.DataFrame'):
|
||||
import cudf # pylint: disable=import-error
|
||||
@@ -737,7 +737,7 @@ async def _predict_async(client: Client, model, data, *args,
|
||||
missing=missing, nthread=worker.nthreads)
|
||||
predt = booster.predict(data=local_x,
|
||||
validate_features=local_x.num_row() != 0,
|
||||
*args)
|
||||
**kwargs)
|
||||
columns = 1 if len(predt.shape) == 1 else predt.shape[1]
|
||||
ret = ((delayed(predt), columns), order)
|
||||
predictions.append(ret)
|
||||
@@ -784,7 +784,7 @@ async def _predict_async(client: Client, model, data, *args,
|
||||
return predictions
|
||||
|
||||
|
||||
def predict(client, model, data, *args, missing=numpy.nan):
|
||||
def predict(client, model, data, missing=numpy.nan, **kwargs):
|
||||
'''Run prediction with a trained booster.
|
||||
|
||||
.. note::
|
||||
@@ -813,8 +813,8 @@ def predict(client, model, data, *args, missing=numpy.nan):
|
||||
'''
|
||||
_assert_dask_support()
|
||||
client = _xgb_get_client(client)
|
||||
return client.sync(_predict_async, client, model, data, *args,
|
||||
missing=missing)
|
||||
return client.sync(_predict_async, client, model, data,
|
||||
missing=missing, **kwargs)
|
||||
|
||||
|
||||
async def _inplace_predict_async(client, model, data,
|
||||
|
||||
Reference in New Issue
Block a user