[dask] Set dataframe index in predict. (#6944)

This commit is contained in:
Jiaming Yuan
2021-05-12 13:24:21 +08:00
committed by GitHub
parent 3e7e426b36
commit 05ac415780
3 changed files with 46 additions and 26 deletions

View File

@@ -62,19 +62,17 @@ def run_with_dask_dataframe(DMatrixT: Type, client: Client) -> None:
assert isinstance(out['booster'], dxgb.Booster)
assert len(out['history']['X']['rmse']) == 4
predictions = dxgb.predict(client, out, dtrain).compute()
assert isinstance(predictions, np.ndarray)
predictions = dxgb.predict(client, out, dtrain)
assert isinstance(predictions.compute(), np.ndarray)
series_predictions = dxgb.inplace_predict(client, out, X)
assert isinstance(series_predictions, dd.Series)
series_predictions = series_predictions.compute()
single_node = out['booster'].predict(
xgboost.DMatrix(X.compute()))
single_node = out['booster'].predict(xgboost.DMatrix(X.compute()))
cp.testing.assert_allclose(single_node, predictions)
cp.testing.assert_allclose(single_node, predictions.compute())
np.testing.assert_allclose(single_node,
series_predictions.to_array())
series_predictions.compute().to_array())
predt = dxgb.predict(client, out, X)
assert isinstance(predt, dd.Series)
@@ -92,6 +90,13 @@ def run_with_dask_dataframe(DMatrixT: Type, client: Client) -> None:
cp.testing.assert_allclose(
predt.values.compute(), single_node)
# Make sure the output can be integrated back to original dataframe
X["predict"] = predictions
X["inplace_predict"] = series_predictions
has_null = X.isnull().values.any().compute()
assert bool(has_null) is False
def run_with_dask_array(DMatrixT: Type, client: Client) -> None:
import cupy as cp