Initial support for IPv6 (#8225)

- Merge rabit socket into XGBoost.
- Dask interface support.
- Add test to the socket.
This commit is contained in:
Jiaming Yuan
2022-09-21 18:06:50 +08:00
committed by GitHub
parent 7d43e74e71
commit b791446623
17 changed files with 924 additions and 595 deletions

View File

@@ -52,6 +52,7 @@ from typing import (
Sequence,
Set,
Tuple,
TypedDict,
TypeVar,
Union,
)
@@ -102,19 +103,13 @@ else:
_DaskCollection = Union["da.Array", "dd.DataFrame", "dd.Series"]
_DataT = Union["da.Array", "dd.DataFrame"] # do not use series as predictor
try:
from mypy_extensions import TypedDict
TrainReturnT = TypedDict(
"TrainReturnT",
{
"booster": Booster,
"history": Dict,
},
)
except ImportError:
TrainReturnT = Dict[str, Any] # type:ignore
TrainReturnT = TypedDict(
"TrainReturnT",
{
"booster": Booster,
"history": Dict,
},
)
__all__ = [
"RabitContext",
@@ -832,11 +827,15 @@ async def _get_rabit_args(
if k not in valid_config:
raise ValueError(f"Unknown configuration: {k}")
host_ip = dconfig.get("scheduler_address", None)
if host_ip is not None and host_ip.startswith("[") and host_ip.endswith("]"):
# convert dask bracket format to proper IPv6 address.
host_ip = host_ip[1:-1]
if host_ip is not None:
try:
host_ip, port = distributed.comm.get_address_host_port(host_ip)
except ValueError:
pass
if host_ip is not None:
user_addr = (host_ip, port)
else:

View File

@@ -0,0 +1,41 @@
"""Utilities for defining Python tests."""
import socket
from platform import system
from typing import TypedDict
PytestSkip = TypedDict("PytestSkip", {"condition": bool, "reason": str})
def has_ipv6() -> bool:
"""Check whether IPv6 is enabled on this host."""
# connection error in macos, still need some fixes.
if system() not in ("Linux", "Windows"):
return False
if socket.has_ipv6:
try:
with socket.socket(
socket.AF_INET6, socket.SOCK_STREAM
) as server, socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as client:
server.bind(("::1", 0))
port = server.getsockname()[1]
server.listen()
client.connect(("::1", port))
conn, _ = server.accept()
client.sendall("abc".encode())
msg = conn.recv(3).decode()
# if the code can be executed to this point, the message should be
# correct.
assert msg == "abc"
return True
except OSError:
pass
return False
def skip_ipv6() -> PytestSkip:
"""PyTest skip mark for IPv6."""
return {"condition": not has_ipv6(), "reason": "IPv6 is required to be enabled."}

View File

@@ -112,7 +112,7 @@ class WorkerEntry:
"""Assign the rank for current entry."""
self.rank = rank
nnset = set(tree_map[rank])
rprev, rnext = ring_map[rank]
rprev, next_rank = ring_map[rank]
self.sock.sendint(rank)
# send parent rank
self.sock.sendint(parent_map[rank])
@@ -129,9 +129,9 @@ class WorkerEntry:
else:
self.sock.sendint(-1)
# send next link
if rnext not in (-1, rank):
nnset.add(rnext)
self.sock.sendint(rnext)
if next_rank not in (-1, rank):
nnset.add(next_rank)
self.sock.sendint(next_rank)
else:
self.sock.sendint(-1)
@@ -157,6 +157,7 @@ class WorkerEntry:
self.sock.sendstr(wait_conn[r].host)
port = wait_conn[r].port
assert port is not None
# send port of this node to other workers so that they can call connect
self.sock.sendint(port)
self.sock.sendint(r)
nerr = self.sock.recvint()