[dask] Improve configuration for port. (#7645)

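- Bind the tracker to port 0 by default so that the OS picks an available port.
- Allow the tracker port to be specified via the `xgboost.scheduler_address` dask configuration (e.g. `"192.0.0.100:12345"`).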
This commit is contained in:
Jiaming Yuan 2022-02-14 21:34:34 +08:00 committed by GitHub
parent b52c4e13b0
commit 5cd1f71b51
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 51 additions and 25 deletions

View File

@ -478,7 +478,7 @@ interface, including callback functions, custom evaluation metric and objective:
.. _tracker-ip: .. _tracker-ip:
*************** ***************
Tracker Host IP Troubleshooting
*************** ***************
.. versionadded:: 1.6.0 .. versionadded:: 1.6.0
@ -499,7 +499,10 @@ dask config is used:
with Client(scheduler_file="sched.json") as client: with Client(scheduler_file="sched.json") as client:
reg = dxgb.DaskXGBRegressor() reg = dxgb.DaskXGBRegressor()
XGBoost will read configuration before training. # or we can specify the port too
with dask.config.set({"xgboost.scheduler_address": "192.0.0.100:12345"}):
reg = dxgb.DaskXGBRegressor()
***************************************************************************** *****************************************************************************
Why is the initialization of ``DaskDMatrix`` so slow and throws weird errors Why is the initialization of ``DaskDMatrix`` so slow and throws weird errors

View File

@ -27,6 +27,8 @@ Optional dask configuration
.. code-block:: python .. code-block:: python
dask.config.set({"xgboost.scheduler_address": "192.0.0.100"}) dask.config.set({"xgboost.scheduler_address": "192.0.0.100"})
# We can also specify the port.
dask.config.set({"xgboost.scheduler_address": "192.0.0.100:12345"})
""" """
import platform import platform
@ -160,10 +162,22 @@ def _multi_lock() -> Any:
def _try_start_tracker( def _try_start_tracker(
n_workers: int, addrs: List[Optional[str]] n_workers: int,
addrs: List[Union[Optional[str], Optional[Tuple[str, int]]]],
) -> Dict[str, Union[int, str]]: ) -> Dict[str, Union[int, str]]:
env: Dict[str, Union[int, str]] = {"DMLC_NUM_WORKER": n_workers} env: Dict[str, Union[int, str]] = {"DMLC_NUM_WORKER": n_workers}
try: try:
if isinstance(addrs[0], tuple):
host_ip = addrs[0][0]
port = addrs[0][1]
rabit_context = RabitTracker(
hostIP=get_host_ip(host_ip),
n_workers=n_workers,
port=port,
use_logger=False,
)
else:
assert isinstance(addrs[0], str) or addrs[0] is None
rabit_context = RabitTracker( rabit_context = RabitTracker(
hostIP=get_host_ip(addrs[0]), n_workers=n_workers, use_logger=False hostIP=get_host_ip(addrs[0]), n_workers=n_workers, use_logger=False
) )
@ -186,7 +200,9 @@ def _try_start_tracker(
def _start_tracker( def _start_tracker(
n_workers: int, addr_from_dask: Optional[str], addr_from_user: Optional[str] n_workers: int,
addr_from_dask: Optional[str],
addr_from_user: Optional[Tuple[str, int]],
) -> Dict[str, Union[int, str]]: ) -> Dict[str, Union[int, str]]:
"""Start Rabit tracker, recurse to try different addresses.""" """Start Rabit tracker, recurse to try different addresses."""
env = _try_start_tracker(n_workers, [addr_from_user, addr_from_dask]) env = _try_start_tracker(n_workers, [addr_from_user, addr_from_dask])
@ -830,13 +846,22 @@ async def _get_rabit_args(
# We try 1 and 3 if 1 is available, otherwise 2 and 3. # We try 1 and 3 if 1 is available, otherwise 2 and 3.
valid_config = ["scheduler_address"] valid_config = ["scheduler_address"]
# See if user config is available # See if user config is available
host_ip: Optional[str] = None
port: int = 0
if dconfig is not None: if dconfig is not None:
for k in dconfig: for k in dconfig:
if k not in valid_config: if k not in valid_config:
raise ValueError(f"Unknown configuration: {k}") raise ValueError(f"Unknown configuration: {k}")
host_ip: Optional[str] = dconfig.get("scheduler_address", None) host_ip = dconfig.get("scheduler_address", None)
try:
host_ip, port = distributed.comm.get_address_host_port(host_ip)
except ValueError:
pass
if host_ip is not None:
user_addr = (host_ip, port)
else: else:
host_ip = None user_addr = None
# Try address from dask scheduler, this might not work, see # Try address from dask scheduler, this might not work, see
# https://github.com/dask/dask-xgboost/pull/40 # https://github.com/dask/dask-xgboost/pull/40
try: try:
@ -845,7 +870,9 @@ async def _get_rabit_args(
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
sched_addr = None sched_addr = None
env = await client.run_on_scheduler(_start_tracker, n_workers, sched_addr, host_ip) env = await client.run_on_scheduler(
_start_tracker, n_workers, sched_addr, user_addr
)
rabit_args = [f"{k}={v}".encode() for k, v in env.items()] rabit_args = [f"{k}={v}".encode() for k, v in env.items()]
return rabit_args return rabit_args

View File

@ -160,8 +160,7 @@ class RabitTracker:
def __init__( def __init__(
self, hostIP: str, self, hostIP: str,
n_workers: int, n_workers: int,
port: int = 9091, port: int = 0,
port_end: int = 9999,
use_logger: bool = False, use_logger: bool = False,
) -> None: ) -> None:
"""A Python implementation of RABIT tracker. """A Python implementation of RABIT tracker.
@ -174,15 +173,8 @@ class RabitTracker:
""" """
sock = socket.socket(get_family(hostIP), socket.SOCK_STREAM) sock = socket.socket(get_family(hostIP), socket.SOCK_STREAM)
for _port in range(port, port_end): sock.bind((hostIP, port))
try: self.port = sock.getsockname()[1]
sock.bind((hostIP, _port))
self.port = _port
break
except socket.error as e:
if e.errno in [98, 48]:
continue
raise
sock.listen(256) sock.listen(256)
self.sock = sock self.sock = sock
self.hostIP = hostIP self.hostIP = hostIP

View File

@ -12,10 +12,10 @@ if sys.platform.startswith("win"):
def test_rabit_tracker(): def test_rabit_tracker():
tracker = RabitTracker(hostIP='127.0.0.1', n_workers=1) tracker = RabitTracker(hostIP='127.0.0.1', n_workers=1)
tracker.start(1) tracker.start(1)
rabit_env = [ worker_env = tracker.worker_envs()
str.encode('DMLC_TRACKER_URI=127.0.0.1'), rabit_env = []
str.encode('DMLC_TRACKER_PORT=9091'), for k, v in worker_env.items():
str.encode('DMLC_TASK_ID=0')] rabit_env.append(f"{k}={v}".encode())
xgb.rabit.init(rabit_env) xgb.rabit.init(rabit_env)
ret = xgb.rabit.broadcast('test1234', 0) ret = xgb.rabit.broadcast('test1234', 0)
assert str(ret) == 'test1234' assert str(ret) == 'test1234'

View File

@ -1228,6 +1228,10 @@ class TestWithDask:
with pytest.raises(ValueError): with pytest.raises(ValueError):
xgb.dask.train(client, {}, dtrain, num_boost_round=4) xgb.dask.train(client, {}, dtrain, num_boost_round=4)
with dask.config.set({'xgboost.scheduler_address': "127.0.0.1:22"}):
with pytest.raises(PermissionError):
xgb.dask.train(client, {}, dtrain, num_boost_round=1)
def run_updater_test( def run_updater_test(
self, self,
client: "Client", client: "Client",