[dask] Add scheduler address to dask config. (#7581)

- Add user configuration.
- Bring back the logic of using the scheduler address from dask. This was removed when we were trying to support GKE; now we bring it back and let xgboost try it if the direct guess or the host IP from the user config fails.
This commit is contained in:
Jiaming Yuan
2022-01-22 01:56:32 +08:00
committed by GitHub
parent 5ddd4a9d06
commit ef4dae4c0e
6 changed files with 136 additions and 24 deletions

View File

@@ -30,6 +30,7 @@ if tm.no_dask()['condition']:
pytest.skip(msg=tm.no_dask()['reason'], allow_module_level=True)
from distributed import LocalCluster, Client
import dask
import dask.dataframe as dd
import dask.array as da
from xgboost.dask import DaskDMatrix
@@ -1219,6 +1220,10 @@ class TestWithDask:
os.remove(before_fname)
os.remove(after_fname)
with dask.config.set({'xgboost.foo': "bar"}):
with pytest.raises(ValueError):
xgb.dask.train(client, {}, dtrain, num_boost_round=4)
def run_updater_test(
self,
client: "Client",
@@ -1318,7 +1323,8 @@ class TestWithDask:
with Client(cluster) as client:
workers = _get_client_workers(client)
rabit_args = client.sync(
xgb.dask._get_rabit_args, len(workers), client)
xgb.dask._get_rabit_args, len(workers), None, client
)
futures = client.map(runit,
workers,
pure=False,
@@ -1446,7 +1452,9 @@ class TestWithDask:
n_partitions = X.npartitions
m = xgb.dask.DaskDMatrix(client, X, y)
workers = _get_client_workers(client)
rabit_args = client.sync(xgb.dask._get_rabit_args, len(workers), client)
rabit_args = client.sync(
xgb.dask._get_rabit_args, len(workers), None, client
)
n_workers = len(workers)
def worker_fn(worker_addr: str, data_ref: Dict) -> None: