[rabit] Improved connection handling. (#9531)

- Enable timeout.
- Report connection error from the system.
- Handle retry for both tracker connection and peer connection.
This commit is contained in:
Jiaming Yuan
2023-08-30 13:00:04 +08:00
committed by GitHub
parent 2462e22cd4
commit ccfc90e4c6
10 changed files with 463 additions and 130 deletions

View File

@@ -20,6 +20,18 @@ def test_rabit_tracker():
assert str(ret) == "test1234"
@pytest.mark.skipif(**tm.not_linux())
def test_socket_error():
    """Check that a failed tracker connection is reported as a ``ValueError``.

    The worker environment produced by a started tracker is altered to point
    at port 0. The error raised while entering the communicator context must
    contain the ``127.0.0.1:0`` address followed, on the next line, by text
    containing "refused".
    """
    tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1)
    tracker.start(1)
    env = tracker.worker_envs()
    # Override the tracker port with 0 so the connection attempt cannot succeed.
    env["DMLC_TRACKER_PORT"] = 0
    # Presumably caps the number of connection attempts so the failure
    # surfaces quickly -- TODO confirm against the communicator options.
    env["DMLC_WORKER_CONNECT_RETRY"] = 1
    # `match` is applied as a regular expression, so the dots of the address
    # are escaped to match literally instead of matching any character.
    with pytest.raises(ValueError, match=r"127\.0\.0\.1:0\n.*refused"):
        with xgb.collective.CommunicatorContext(**env):
            pass
def run_rabit_ops(client, n_workers):
from xgboost.dask import CommunicatorContext, _get_dask_config, _get_rabit_args
@@ -58,6 +70,32 @@ def test_rabit_ops():
run_rabit_ops(client, n_workers)
def run_broadcast(client):
    """Run ``collective.broadcast(17, 0)`` on every worker and verify the output.

    ``client`` is used to obtain the communicator arguments, to map one task
    onto each worker, and to gather the value each task returned. Every
    gathered value is expected to be close to 17.
    """
    from xgboost.dask import _get_dask_config, _get_rabit_args

    workers = tm.get_client_workers(client)
    n_workers = len(workers)
    rabit_args = client.sync(_get_rabit_args, n_workers, _get_dask_config(), client)

    def local_test(worker_id):
        # Join the communicator, then hand back whatever the broadcast yields.
        # NOTE(review): the second argument is presumably the root rank -- confirm.
        with collective.CommunicatorContext(**rabit_args):
            return collective.broadcast(17, 0)

    futures = client.map(local_test, range(n_workers), workers=workers)
    gathered = np.array(client.gather(futures))
    np.testing.assert_allclose(gathered, 17)
@pytest.mark.skipif(**tm.no_dask())
def test_broadcast():
    """Exercise ``run_broadcast`` with a client attached to a 3-worker local cluster."""
    from distributed import Client, LocalCluster

    # Both context managers are closed in reverse order when the block exits.
    with LocalCluster(n_workers=3) as cluster, Client(cluster) as client:
        run_broadcast(client)
@pytest.mark.skipif(**tm.no_ipv6())
@pytest.mark.skipif(**tm.no_dask())
def test_rabit_ops_ipv6():