[dask] prediction with categorical data. (#7708)
This commit is contained in:
@@ -288,10 +288,23 @@ def run_categorical(client: "Client", tree_method: str, X, X_onehot, y) -> None:
|
||||
reg.fit(X, y, eval_set=[(X, y)])
|
||||
assert tm.non_increasing(reg.evals_result()["validation_0"]["rmse"])
|
||||
|
||||
booster = reg.get_booster()
|
||||
predt = xgb.dask.predict(client, booster, X).compute().values
|
||||
inpredt = xgb.dask.inplace_predict(client, booster, X).compute().values
|
||||
|
||||
if hasattr(predt, "get"):
|
||||
predt = predt.get()
|
||||
if hasattr(inpredt, "get"):
|
||||
inpredt = inpredt.get()
|
||||
|
||||
np.testing.assert_allclose(predt, inpredt)
|
||||
|
||||
|
||||
def test_categorical(client: "Client") -> None:
|
||||
X, y = make_categorical(client, 10000, 30, 13)
|
||||
X_onehot, _ = make_categorical(client, 10000, 30, 13, True)
|
||||
run_categorical(client, "approx", X, X_onehot, y)
|
||||
run_categorical(client, "hist", X, X_onehot, y)
|
||||
|
||||
|
||||
def test_dask_predict_shape_infer(client: "Client") -> None:
|
||||
|
||||
Reference in New Issue
Block a user