[dask] Improve configuration for port. (#7645)

- Bind to port 0 so that the OS assigns an available port.
- Add port configuration.
This commit is contained in:
Jiaming Yuan 2022-02-14 21:34:34 +08:00 committed by GitHub
parent b52c4e13b0
commit 5cd1f71b51
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 51 additions and 25 deletions

View File

@ -478,7 +478,7 @@ interface, including callback functions, custom evaluation metric and objective:
.. _tracker-ip:
***************
Tracker Host IP
Troubleshooting
***************
.. versionadded:: 1.6.0
@ -499,7 +499,10 @@ dask config is used:
with Client(scheduler_file="sched.json") as client:
reg = dxgb.DaskXGBRegressor()
XGBoost will read configuration before training.
# Or we can specify the port as well, using the "host:port" form.
with dask.config.set({"xgboost.scheduler_address": "192.0.0.100:12345"}):
reg = dxgb.DaskXGBRegressor()
*****************************************************************************
Why is the initialization of ``DaskDMatrix`` so slow and throws weird errors

View File

@ -27,6 +27,8 @@ Optional dask configuration
.. code-block:: python
dask.config.set({"xgboost.scheduler_address": "192.0.0.100"})
# We can also specify the port.
dask.config.set({"xgboost.scheduler_address": "192.0.0.100:12345"})
"""
import platform
@ -160,10 +162,22 @@ def _multi_lock() -> Any:
def _try_start_tracker(
n_workers: int, addrs: List[Optional[str]]
n_workers: int,
addrs: List[Union[Optional[str], Optional[Tuple[str, int]]]],
) -> Dict[str, Union[int, str]]:
env: Dict[str, Union[int, str]] = {"DMLC_NUM_WORKER": n_workers}
try:
if isinstance(addrs[0], tuple):
host_ip = addrs[0][0]
port = addrs[0][1]
rabit_context = RabitTracker(
hostIP=get_host_ip(host_ip),
n_workers=n_workers,
port=port,
use_logger=False,
)
else:
assert isinstance(addrs[0], str) or addrs[0] is None
rabit_context = RabitTracker(
hostIP=get_host_ip(addrs[0]), n_workers=n_workers, use_logger=False
)
@ -186,7 +200,9 @@ def _try_start_tracker(
def _start_tracker(
n_workers: int, addr_from_dask: Optional[str], addr_from_user: Optional[str]
n_workers: int,
addr_from_dask: Optional[str],
addr_from_user: Optional[Tuple[str, int]],
) -> Dict[str, Union[int, str]]:
"""Start Rabit tracker, recurse to try different addresses."""
env = _try_start_tracker(n_workers, [addr_from_user, addr_from_dask])
@ -830,13 +846,22 @@ async def _get_rabit_args(
# We try 1 and 3 if 1 is available, otherwise 2 and 3.
valid_config = ["scheduler_address"]
# See if user config is available
host_ip: Optional[str] = None
port: int = 0
if dconfig is not None:
for k in dconfig:
if k not in valid_config:
raise ValueError(f"Unknown configuration: {k}")
host_ip: Optional[str] = dconfig.get("scheduler_address", None)
host_ip = dconfig.get("scheduler_address", None)
try:
host_ip, port = distributed.comm.get_address_host_port(host_ip)
except ValueError:
pass
if host_ip is not None:
user_addr = (host_ip, port)
else:
host_ip = None
user_addr = None
# Try the address from the dask scheduler; this might not work, see
# https://github.com/dask/dask-xgboost/pull/40
try:
@ -845,7 +870,9 @@ async def _get_rabit_args(
except Exception: # pylint: disable=broad-except
sched_addr = None
env = await client.run_on_scheduler(_start_tracker, n_workers, sched_addr, host_ip)
env = await client.run_on_scheduler(
_start_tracker, n_workers, sched_addr, user_addr
)
rabit_args = [f"{k}={v}".encode() for k, v in env.items()]
return rabit_args

View File

@ -160,8 +160,7 @@ class RabitTracker:
def __init__(
self, hostIP: str,
n_workers: int,
port: int = 9091,
port_end: int = 9999,
port: int = 0,
use_logger: bool = False,
) -> None:
"""A Python implementation of RABIT tracker.
@ -174,15 +173,8 @@ class RabitTracker:
"""
sock = socket.socket(get_family(hostIP), socket.SOCK_STREAM)
for _port in range(port, port_end):
try:
sock.bind((hostIP, _port))
self.port = _port
break
except socket.error as e:
if e.errno in [98, 48]:
continue
raise
sock.bind((hostIP, port))
self.port = sock.getsockname()[1]
sock.listen(256)
self.sock = sock
self.hostIP = hostIP

View File

@ -12,10 +12,10 @@ if sys.platform.startswith("win"):
def test_rabit_tracker():
tracker = RabitTracker(hostIP='127.0.0.1', n_workers=1)
tracker.start(1)
rabit_env = [
str.encode('DMLC_TRACKER_URI=127.0.0.1'),
str.encode('DMLC_TRACKER_PORT=9091'),
str.encode('DMLC_TASK_ID=0')]
worker_env = tracker.worker_envs()
rabit_env = []
for k, v in worker_env.items():
rabit_env.append(f"{k}={v}".encode())
xgb.rabit.init(rabit_env)
ret = xgb.rabit.broadcast('test1234', 0)
assert str(ret) == 'test1234'

View File

@ -1228,6 +1228,10 @@ class TestWithDask:
with pytest.raises(ValueError):
xgb.dask.train(client, {}, dtrain, num_boost_round=4)
with dask.config.set({'xgboost.scheduler_address': "127.0.0.1:22"}):
with pytest.raises(PermissionError):
xgb.dask.train(client, {}, dtrain, num_boost_round=1)
def run_updater_test(
self,
client: "Client",