[dask] Accept other inputs for prediction. (#5428)
* Returns a series when input is dataframe. * Merge assert client.
This commit is contained in:
@@ -57,7 +57,13 @@ def test_from_dask_dataframe():
|
||||
xgb.dask.train(
|
||||
client, {}, dtrain, num_boost_round=2, evals_result={})
|
||||
# force prediction to be computed
|
||||
prediction = prediction.compute()
|
||||
from_dmatrix = prediction.compute()
|
||||
|
||||
prediction = xgb.dask.predict(client, model=booster, data=X)
|
||||
from_df = prediction.compute()
|
||||
|
||||
assert isinstance(prediction, dd.Series)
|
||||
assert np.all(from_dmatrix == from_df.to_numpy())
|
||||
|
||||
|
||||
def test_from_dask_array():
|
||||
@@ -84,6 +90,12 @@ def test_from_dask_array():
|
||||
config = json.loads(booster.save_config())
|
||||
assert int(config['learner']['generic_param']['nthread']) == 5
|
||||
|
||||
from_arr = xgb.dask.predict(
|
||||
client, model=booster, data=X)
|
||||
|
||||
assert isinstance(from_arr, da.Array)
|
||||
assert np.all(single_node_predt == from_arr.compute())
|
||||
|
||||
|
||||
def test_dask_regressor():
|
||||
with LocalCluster(n_workers=5) as cluster:
|
||||
|
||||
Reference in New Issue
Block a user