Initial support for IPv6 (#8225)

- Merge rabit socket into XGBoost.
- Dask interface support.
- Add test to the socket.
This commit is contained in:
Jiaming Yuan
2022-09-21 18:06:50 +08:00
committed by GitHub
parent 7d43e74e71
commit b791446623
17 changed files with 924 additions and 595 deletions

View File

@@ -1,34 +1,37 @@
from xgboost import RabitTracker
import xgboost as xgb
import re
import sys
import numpy as np
import pytest
import testing as tm
import numpy as np
import sys
import re
import xgboost as xgb
from xgboost import RabitTracker, testing
if sys.platform.startswith("win"):
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
def test_rabit_tracker():
tracker = RabitTracker(host_ip='127.0.0.1', n_workers=1)
tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1)
tracker.start(1)
worker_env = tracker.worker_envs()
rabit_env = []
for k, v in worker_env.items():
rabit_env.append(f"{k}={v}".encode())
with xgb.rabit.RabitContext(rabit_env):
ret = xgb.rabit.broadcast('test1234', 0)
assert str(ret) == 'test1234'
ret = xgb.rabit.broadcast("test1234", 0)
assert str(ret) == "test1234"
def run_rabit_ops(client, n_workers):
from test_with_dask import _get_client_workers
from xgboost.dask import RabitContext, _get_rabit_args
from xgboost.dask import RabitContext, _get_dask_config, _get_rabit_args
from xgboost import rabit
workers = _get_client_workers(client)
rabit_args = client.sync(_get_rabit_args, len(workers), None, client)
rabit_args = client.sync(_get_rabit_args, len(workers), _get_dask_config(), client)
assert not rabit.is_distributed()
n_workers_from_dask = len(workers)
assert n_workers == n_workers_from_dask
@@ -55,12 +58,26 @@ def run_rabit_ops(client, n_workers):
@pytest.mark.skipif(**tm.no_dask())
def test_rabit_ops():
from distributed import Client, LocalCluster
n_workers = 3
with LocalCluster(n_workers=n_workers) as cluster:
with Client(cluster) as client:
run_rabit_ops(client, n_workers)
@pytest.mark.skipif(**testing.skip_ipv6())
@pytest.mark.skipif(**tm.no_dask())
def test_rabit_ops_ipv6():
import dask
from distributed import Client, LocalCluster
n_workers = 3
with dask.config.set({"xgboost.scheduler_address": "[::1]"}):
with LocalCluster(n_workers=n_workers, host="[::1]") as cluster:
with Client(cluster) as client:
run_rabit_ops(client, n_workers)
def test_rank_assignment() -> None:
from distributed import Client, LocalCluster
from test_with_dask import _get_client_workers