[dask] Improve configuration for port. (#7645)

- Try port 0 to let the OS return the available port.
- Add port configuration.
This commit is contained in:
Jiaming Yuan
2022-02-14 21:34:34 +08:00
committed by GitHub
parent b52c4e13b0
commit 5cd1f71b51
5 changed files with 51 additions and 25 deletions

View File

@@ -27,6 +27,8 @@ Optional dask configuration
.. code-block:: python
dask.config.set({"xgboost.scheduler_address": "192.0.0.100"})
# We can also specify the port.
dask.config.set({"xgboost.scheduler_address": "192.0.0.100:12345"})
"""
import platform
@@ -160,13 +162,25 @@ def _multi_lock() -> Any:
def _try_start_tracker(
n_workers: int, addrs: List[Optional[str]]
n_workers: int,
addrs: List[Union[Optional[str], Optional[Tuple[str, int]]]],
) -> Dict[str, Union[int, str]]:
env: Dict[str, Union[int, str]] = {"DMLC_NUM_WORKER": n_workers}
try:
rabit_context = RabitTracker(
hostIP=get_host_ip(addrs[0]), n_workers=n_workers, use_logger=False
)
if isinstance(addrs[0], tuple):
host_ip = addrs[0][0]
port = addrs[0][1]
rabit_context = RabitTracker(
hostIP=get_host_ip(host_ip),
n_workers=n_workers,
port=port,
use_logger=False,
)
else:
assert isinstance(addrs[0], str) or addrs[0] is None
rabit_context = RabitTracker(
hostIP=get_host_ip(addrs[0]), n_workers=n_workers, use_logger=False
)
env.update(rabit_context.worker_envs())
rabit_context.start(n_workers)
thread = Thread(target=rabit_context.join)
@@ -186,7 +200,9 @@ def _try_start_tracker(
def _start_tracker(
n_workers: int, addr_from_dask: Optional[str], addr_from_user: Optional[str]
n_workers: int,
addr_from_dask: Optional[str],
addr_from_user: Optional[Tuple[str, int]],
) -> Dict[str, Union[int, str]]:
"""Start Rabit tracker, recurse to try different addresses."""
env = _try_start_tracker(n_workers, [addr_from_user, addr_from_dask])
@@ -830,13 +846,22 @@ async def _get_rabit_args(
# We try 1 and 3 if 1 is available, otherwise 2 and 3.
valid_config = ["scheduler_address"]
# See if user config is available
host_ip: Optional[str] = None
port: int = 0
if dconfig is not None:
for k in dconfig:
if k not in valid_config:
raise ValueError(f"Unknown configuration: {k}")
host_ip: Optional[str] = dconfig.get("scheduler_address", None)
host_ip = dconfig.get("scheduler_address", None)
try:
host_ip, port = distributed.comm.get_address_host_port(host_ip)
except ValueError:
pass
if host_ip is not None:
user_addr = (host_ip, port)
else:
host_ip = None
user_addr = None
# Try address from dask scheduler, this might not work, see
# https://github.com/dask/dask-xgboost/pull/40
try:
@@ -845,7 +870,9 @@ async def _get_rabit_args(
except Exception: # pylint: disable=broad-except
sched_addr = None
env = await client.run_on_scheduler(_start_tracker, n_workers, sched_addr, host_ip)
env = await client.run_on_scheduler(
_start_tracker, n_workers, sched_addr, user_addr
)
rabit_args = [f"{k}={v}".encode() for k, v in env.items()]
return rabit_args

View File

@@ -160,8 +160,7 @@ class RabitTracker:
def __init__(
self, hostIP: str,
n_workers: int,
port: int = 9091,
port_end: int = 9999,
port: int = 0,
use_logger: bool = False,
) -> None:
"""A Python implementation of RABIT tracker.
@@ -174,15 +173,8 @@ class RabitTracker:
"""
sock = socket.socket(get_family(hostIP), socket.SOCK_STREAM)
for _port in range(port, port_end):
try:
sock.bind((hostIP, _port))
self.port = _port
break
except socket.error as e:
if e.errno in [98, 48]:
continue
raise
sock.bind((hostIP, port))
self.port = sock.getsockname()[1]
sock.listen(256)
self.sock = sock
self.hostIP = hostIP