[dask] Accept other inputs for prediction. (#5428)

* Returns a series when input is dataframe.

* Merge assert client.
This commit is contained in:
Jiaming Yuan
2020-03-19 17:05:55 +08:00
committed by GitHub
parent 8ca06ab329
commit 760d5d0c3c
2 changed files with 56 additions and 27 deletions

View File

@@ -57,7 +57,13 @@ def test_from_dask_dataframe():
xgb.dask.train(
client, {}, dtrain, num_boost_round=2, evals_result={})
# force prediction to be computed
prediction = prediction.compute()
from_dmatrix = prediction.compute()
prediction = xgb.dask.predict(client, model=booster, data=X)
from_df = prediction.compute()
assert isinstance(prediction, dd.Series)
assert np.all(from_dmatrix == from_df.to_numpy())
def test_from_dask_array():
@@ -84,6 +90,12 @@ def test_from_dask_array():
config = json.loads(booster.save_config())
assert int(config['learner']['generic_param']['nthread']) == 5
from_arr = xgb.dask.predict(
client, model=booster, data=X)
assert isinstance(from_arr, da.Array)
assert np.all(single_node_predt == from_arr.compute())
def test_dask_regressor():
with LocalCluster(n_workers=5) as cluster: