[dask] prediction with categorical data. (#7708)

This commit is contained in:
Jiaming Yuan
2022-03-10 00:21:48 +08:00
committed by GitHub
parent 68b6d6bbe2
commit a62a3d991d
4 changed files with 40 additions and 19 deletions

View File

@@ -288,10 +288,23 @@ def run_categorical(client: "Client", tree_method: str, X, X_onehot, y) -> None:
reg.fit(X, y, eval_set=[(X, y)])
assert tm.non_increasing(reg.evals_result()["validation_0"]["rmse"])
booster = reg.get_booster()
predt = xgb.dask.predict(client, booster, X).compute().values
inpredt = xgb.dask.inplace_predict(client, booster, X).compute().values
if hasattr(predt, "get"):
predt = predt.get()
if hasattr(inpredt, "get"):
inpredt = inpredt.get()
np.testing.assert_allclose(predt, inpredt)
def test_categorical(client: "Client") -> None:
X, y = make_categorical(client, 10000, 30, 13)
X_onehot, _ = make_categorical(client, 10000, 30, 13, True)
run_categorical(client, "approx", X, X_onehot, y)
run_categorical(client, "hist", X, X_onehot, y)
def test_dask_predict_shape_infer(client: "Client") -> None: