From 5cd1f71b519e82dd7dc5a04678bd771066cba26d Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 14 Feb 2022 21:34:34 +0800 Subject: [PATCH] [dask] Improve configuration for port. (#7645) - Try port 0 to let the OS return the available port. - Add port configuration. --- doc/tutorials/dask.rst | 7 +++-- python-package/xgboost/dask.py | 43 +++++++++++++++++++++++++------ python-package/xgboost/tracker.py | 14 +++------- tests/python/test_tracker.py | 8 +++--- tests/python/test_with_dask.py | 4 +++ 5 files changed, 51 insertions(+), 25 deletions(-) diff --git a/doc/tutorials/dask.rst b/doc/tutorials/dask.rst index 465f86410..44bb643a2 100644 --- a/doc/tutorials/dask.rst +++ b/doc/tutorials/dask.rst @@ -478,7 +478,7 @@ interface, including callback functions, custom evaluation metric and objective: .. _tracker-ip: *************** -Tracker Host IP +Troubleshooting *************** .. versionadded:: 1.6.0 @@ -499,7 +499,10 @@ dask config is used: with Client(scheduler_file="sched.json") as client: reg = dxgb.DaskXGBRegressor() -XGBoost will read configuration before training. + # or we can specify the port too + with dask.config.set({"xgboost.scheduler_address": "192.0.0.100:12345"}): + reg = dxgb.DaskXGBRegressor() + ***************************************************************************** Why is the initialization of ``DaskDMatrix`` so slow and throws weird errors diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 8184d19ea..67265c3e4 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -27,6 +27,8 @@ Optional dask configuration .. code-block:: python dask.config.set({"xgboost.scheduler_address": "192.0.0.100"}) + # We can also specify the port. + dask.config.set({"xgboost.scheduler_address": "192.0.0.100:12345"}) """ import platform @@ -160,13 +162,25 @@ def _multi_lock() -> Any: def _try_start_tracker( - n_workers: int, addrs: List[Optional[str]] + n_workers: int, + addrs: List[Union[Optional[str], Optional[Tuple[str, int]]]], ) -> Dict[str, Union[int, str]]: env: Dict[str, Union[int, str]] = {"DMLC_NUM_WORKER": n_workers} try: - rabit_context = RabitTracker( - hostIP=get_host_ip(addrs[0]), n_workers=n_workers, use_logger=False - ) + if isinstance(addrs[0], tuple): + host_ip = addrs[0][0] + port = addrs[0][1] + rabit_context = RabitTracker( + hostIP=get_host_ip(host_ip), + n_workers=n_workers, + port=port, + use_logger=False, + ) + else: + assert isinstance(addrs[0], str) or addrs[0] is None + rabit_context = RabitTracker( + hostIP=get_host_ip(addrs[0]), n_workers=n_workers, use_logger=False + ) env.update(rabit_context.worker_envs()) rabit_context.start(n_workers) thread = Thread(target=rabit_context.join) @@ -186,7 +200,9 @@ def _try_start_tracker( def _start_tracker( - n_workers: int, addr_from_dask: Optional[str], addr_from_user: Optional[str] + n_workers: int, + addr_from_dask: Optional[str], + addr_from_user: Optional[Tuple[str, int]], ) -> Dict[str, Union[int, str]]: """Start Rabit tracker, recurse to try different addresses.""" env = _try_start_tracker(n_workers, [addr_from_user, addr_from_dask]) @@ -830,13 +846,22 @@ async def _get_rabit_args( # We try 1 and 3 if 1 is available, otherwise 2 and 3. valid_config = ["scheduler_address"] # See if user config is available + host_ip: Optional[str] = None + port: int = 0 if dconfig is not None: for k in dconfig: if k not in valid_config: raise ValueError(f"Unknown configuration: {k}") - host_ip: Optional[str] = dconfig.get("scheduler_address", None) + host_ip = dconfig.get("scheduler_address", None) + try: + host_ip, port = distributed.comm.get_address_host_port(host_ip) + except ValueError: + pass + if host_ip is not None: + user_addr = (host_ip, port) else: - host_ip = None + user_addr = None + # Try address from dask scheduler, this might not work, see # https://github.com/dask/dask-xgboost/pull/40 try: @@ -845,7 +870,9 @@ async def _get_rabit_args( except Exception: # pylint: disable=broad-except sched_addr = None - env = await client.run_on_scheduler(_start_tracker, n_workers, sched_addr, host_ip) + env = await client.run_on_scheduler( + _start_tracker, n_workers, sched_addr, user_addr + ) rabit_args = [f"{k}={v}".encode() for k, v in env.items()] return rabit_args diff --git a/python-package/xgboost/tracker.py b/python-package/xgboost/tracker.py index 9e040d05b..4412ef3e9 100644 --- a/python-package/xgboost/tracker.py +++ b/python-package/xgboost/tracker.py @@ -160,8 +160,7 @@ class RabitTracker: def __init__( self, hostIP: str, n_workers: int, - port: int = 9091, - port_end: int = 9999, + port: int = 0, use_logger: bool = False, ) -> None: """A Python implementation of RABIT tracker. @@ -174,15 +173,8 @@ class RabitTracker: """ sock = socket.socket(get_family(hostIP), socket.SOCK_STREAM) - for _port in range(port, port_end): - try: - sock.bind((hostIP, _port)) - self.port = _port - break - except socket.error as e: - if e.errno in [98, 48]: - continue - raise + sock.bind((hostIP, port)) + self.port = sock.getsockname()[1] sock.listen(256) self.sock = sock self.hostIP = hostIP diff --git a/tests/python/test_tracker.py b/tests/python/test_tracker.py index 0ba7199eb..2f19f6933 100644 --- a/tests/python/test_tracker.py +++ b/tests/python/test_tracker.py @@ -12,10 +12,10 @@ if sys.platform.startswith("win"): def test_rabit_tracker(): tracker = RabitTracker(hostIP='127.0.0.1', n_workers=1) tracker.start(1) - rabit_env = [ - str.encode('DMLC_TRACKER_URI=127.0.0.1'), - str.encode('DMLC_TRACKER_PORT=9091'), - str.encode('DMLC_TASK_ID=0')] + worker_env = tracker.worker_envs() + rabit_env = [] + for k, v in worker_env.items(): + rabit_env.append(f"{k}={v}".encode()) xgb.rabit.init(rabit_env) ret = xgb.rabit.broadcast('test1234', 0) assert str(ret) == 'test1234' diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index 272d816de..55fd22e02 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -1228,6 +1228,10 @@ class TestWithDask: with pytest.raises(ValueError): xgb.dask.train(client, {}, dtrain, num_boost_round=4) + with dask.config.set({'xgboost.scheduler_address': "127.0.0.1:22"}): + with pytest.raises(PermissionError): + xgb.dask.train(client, {}, dtrain, num_boost_round=1) + def run_updater_test( self, client: "Client",