[dask] Set dataframe index in predict. (#6944)

This commit is contained in:
Jiaming Yuan
2021-05-12 13:24:21 +08:00
committed by GitHub
parent 3e7e426b36
commit 05ac415780
3 changed files with 46 additions and 26 deletions

View File

@@ -100,6 +100,12 @@ def test_from_dask_dataframe() -> None:
np.testing.assert_allclose(series_predictions.compute().values,
from_dmatrix)
# Make sure the output can be integrated back to original dataframe
X["predict"] = prediction
X["inplace_predict"] = series_predictions
assert bool(X.isnull().values.any().compute()) is False
def test_from_dask_array() -> None:
with LocalCluster(n_workers=kWorkers, threads_per_worker=5) as cluster: