[dask] Set dataframe index in predict. (#6944)
This commit is contained in:
@@ -62,19 +62,17 @@ def run_with_dask_dataframe(DMatrixT: Type, client: Client) -> None:
|
||||
assert isinstance(out['booster'], dxgb.Booster)
|
||||
assert len(out['history']['X']['rmse']) == 4
|
||||
|
||||
predictions = dxgb.predict(client, out, dtrain).compute()
|
||||
assert isinstance(predictions, np.ndarray)
|
||||
predictions = dxgb.predict(client, out, dtrain)
|
||||
assert isinstance(predictions.compute(), np.ndarray)
|
||||
|
||||
series_predictions = dxgb.inplace_predict(client, out, X)
|
||||
assert isinstance(series_predictions, dd.Series)
|
||||
series_predictions = series_predictions.compute()
|
||||
|
||||
single_node = out['booster'].predict(
|
||||
xgboost.DMatrix(X.compute()))
|
||||
single_node = out['booster'].predict(xgboost.DMatrix(X.compute()))
|
||||
|
||||
cp.testing.assert_allclose(single_node, predictions)
|
||||
cp.testing.assert_allclose(single_node, predictions.compute())
|
||||
np.testing.assert_allclose(single_node,
|
||||
series_predictions.to_array())
|
||||
series_predictions.compute().to_array())
|
||||
|
||||
predt = dxgb.predict(client, out, X)
|
||||
assert isinstance(predt, dd.Series)
|
||||
@@ -92,6 +90,13 @@ def run_with_dask_dataframe(DMatrixT: Type, client: Client) -> None:
|
||||
cp.testing.assert_allclose(
|
||||
predt.values.compute(), single_node)
|
||||
|
||||
# Make sure the output can be integrated back to original dataframe
|
||||
X["predict"] = predictions
|
||||
X["inplace_predict"] = series_predictions
|
||||
|
||||
has_null = X.isnull().values.any().compute()
|
||||
assert bool(has_null) is False
|
||||
|
||||
|
||||
def run_with_dask_array(DMatrixT: Type, client: Client) -> None:
|
||||
import cupy as cp
|
||||
|
||||
Reference in New Issue
Block a user