[dask] Disable gblinear and dart. (#6665)
This commit is contained in:
parent
9d62b14591
commit
72892cc80d
@ -802,6 +802,11 @@ async def _train_async(
|
||||
workers = list(_get_workers_from_data(dtrain, evals))
|
||||
_rabit_args = await _get_rabit_args(len(workers), client)
|
||||
|
||||
if params.get("booster", None) is not None and params["booster"] != "gbtree":
|
||||
raise NotImplementedError(
|
||||
f"booster `{params['booster']}` is not yet supported for dask."
|
||||
)
|
||||
|
||||
def dispatched_train(
|
||||
worker_addr: str,
|
||||
rabit_args: List[bytes],
|
||||
|
||||
@ -1167,6 +1167,19 @@ class TestWithDask:
|
||||
np.testing.assert_allclose(predt_0.compute(), predt_3)
|
||||
|
||||
|
||||
def test_unsupported_features(client: "Client"):
|
||||
X, y, _ = generate_array()
|
||||
# gblinear doesn't support distributed training.
|
||||
with pytest.raises(NotImplementedError, match="gblinear"):
|
||||
xgb.dask.train(
|
||||
client, {"booster": "gblinear"}, xgb.dask.DaskDMatrix(client, X, y)
|
||||
)
|
||||
# dart prediction is not thread safe, running predict with each partition will have
|
||||
# race.
|
||||
with pytest.raises(NotImplementedError, match="dart"):
|
||||
xgb.dask.train(client, {"booster": "dart"}, xgb.dask.DaskDMatrix(client, X, y))
|
||||
|
||||
|
||||
class TestDaskCallbacks:
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_early_stopping(self, client: "Client") -> None:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user