[dask] Disable gblinear and dart. (#6665)

This commit is contained in:
Jiaming Yuan 2021-02-04 09:13:09 +08:00 committed by GitHub
parent 9d62b14591
commit 72892cc80d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 0 deletions

View File

@ -802,6 +802,11 @@ async def _train_async(
workers = list(_get_workers_from_data(dtrain, evals))
_rabit_args = await _get_rabit_args(len(workers), client)
if params.get("booster", None) is not None and params["booster"] != "gbtree":
raise NotImplementedError(
f"booster `{params['booster']}` is not yet supported for dask."
)
def dispatched_train(
worker_addr: str,
rabit_args: List[bytes],

View File

@ -1167,6 +1167,19 @@ class TestWithDask:
np.testing.assert_allclose(predt_0.compute(), predt_3)
def test_unsupported_features(client: "Client"):
X, y, _ = generate_array()
# gblinear doesn't support distributed training.
with pytest.raises(NotImplementedError, match="gblinear"):
xgb.dask.train(
client, {"booster": "gblinear"}, xgb.dask.DaskDMatrix(client, X, y)
)
# dart prediction is not thread safe, running predict with each partition will have
# race.
with pytest.raises(NotImplementedError, match="dart"):
xgb.dask.train(client, {"booster": "dart"}, xgb.dask.DaskDMatrix(client, X, y))
class TestDaskCallbacks:
@pytest.mark.skipif(**tm.no_sklearn())
def test_early_stopping(self, client: "Client") -> None: