Revamp the rabit implementation. (#10112)

This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features:
- Federated learning for both CPU and GPU.
- NCCL.
- More data types.
- A unified interface for all the underlying implementations.
- Improved timeout handling for both tracker and workers.
- Exhaustive tests with metrics (fixed a couple of bugs along the way).
- A reusable tracker for Python and JVM packages.
This commit is contained in:
Jiaming Yuan
2024-05-20 11:56:23 +08:00
committed by GitHub
parent ba9b4cb1ee
commit a5a58102e5
195 changed files with 2768 additions and 9234 deletions

View File

@@ -1,17 +1,17 @@
"""XGBoost collective communication related API."""
import ctypes
import json
import logging
import os
import pickle
import platform
from enum import IntEnum, unique
from typing import Any, Dict, List, Optional
import numpy as np
from ._typing import _T
from .core import _LIB, _check_call, build_info, c_str, from_pystr_to_cstr, py_str
from .core import _LIB, _check_call, build_info, c_str, make_jcargs, py_str
LOGGER = logging.getLogger("[xgboost.collective]")
@@ -21,49 +21,35 @@ def init(**args: Any) -> None:
Parameters
----------
args: Dict[str, Any]
args :
Keyword arguments representing the parameters and their values.
Accepted parameters:
- xgboost_communicator: The type of the communicator. Can be set as an environment
variable.
- dmlc_communicator: The type of the communicator.
* rabit: Use Rabit. This is the default if the type is unspecified.
* federated: Use the gRPC interface for Federated Learning.
Only applicable to the Rabit communicator (these are case sensitive):
-- rabit_tracker_uri: Hostname of the tracker.
-- rabit_tracker_port: Port number of the tracker.
-- rabit_task_id: ID of the current task, can be used to obtain deterministic rank
assignment.
-- rabit_world_size: Total number of workers.
-- rabit_hadoop_mode: Enable Hadoop support.
-- rabit_tree_reduce_minsize: Minimal size for tree reduce.
-- rabit_reduce_ring_mincount: Minimal count to perform ring reduce.
-- rabit_reduce_buffer: Size of the reduce buffer.
-- rabit_bootstrap_cache: Size of the bootstrap cache.
-- rabit_debug: Enable debugging.
-- rabit_timeout: Enable timeout.
-- rabit_timeout_sec: Timeout in seconds.
-- rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms.
Only applicable to the Rabit communicator (these are case-sensitive, and can be set as
environment variables):
-- DMLC_TRACKER_URI: Hostname of the tracker.
-- DMLC_TRACKER_PORT: Port number of the tracker.
-- DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank
assignment.
-- DMLC_ROLE: Role of the current task, "worker" or "server".
-- DMLC_NUM_ATTEMPT: Number of attempts after task failure.
-- DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
Only applicable to the Federated communicator (use upper case for environment variables, use
lower case for runtime configuration):
-- federated_server_address: Address of the federated server.
-- federated_world_size: Number of federated workers.
-- federated_rank: Rank of the current worker.
-- federated_server_cert: Server certificate file path. Only needed for the SSL mode.
-- federated_client_key: Client key file path. Only needed for the SSL mode.
-- federated_client_cert: Client certificate file path. Only needed for the SSL mode.
Only applicable to the Rabit communicator:
- dmlc_tracker_uri: Hostname of the tracker.
- dmlc_tracker_port: Port number of the tracker.
- dmlc_task_id: ID of the current task, can be used to obtain deterministic
- dmlc_retry: The number of retry when handling network errors.
- dmlc_timeout: Timeout in seconds.
- dmlc_nccl_path: Path to load (dlopen) nccl for GPU-based communication.
Only applicable to the Federated communicator (use upper case for environment
variables, use lower case for runtime configuration):
- federated_server_address: Address of the federated server.
- federated_world_size: Number of federated workers.
- federated_rank: Rank of the current worker.
- federated_server_cert: Server certificate file path. Only needed for the SSL
mode.
- federated_client_key: Client key file path. Only needed for the SSL mode.
- federated_client_cert: Client certificate file path. Only needed for the SSL
mode.
"""
config = from_pystr_to_cstr(json.dumps(args))
_check_call(_LIB.XGCommunicatorInit(config))
_check_call(_LIB.XGCommunicatorInit(make_jcargs(**args)))
def finalize() -> None:
@@ -157,7 +143,7 @@ def broadcast(data: _T, root: int) -> _T:
assert data is not None, "need to pass in data when broadcasting"
s = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
length.value = len(s)
# run first broadcast
# Run first broadcast
_check_call(
_LIB.XGCommunicatorBroadcast(
ctypes.byref(length), ctypes.sizeof(ctypes.c_ulong), root
@@ -184,16 +170,27 @@ def broadcast(data: _T, root: int) -> _T:
# enumeration of dtypes
DTYPE_ENUM__ = {
np.dtype("int8"): 0,
np.dtype("uint8"): 1,
np.dtype("int32"): 2,
np.dtype("uint32"): 3,
np.dtype("int64"): 4,
np.dtype("uint64"): 5,
np.dtype("float32"): 6,
np.dtype("float64"): 7,
}
def _map_dtype(dtype: np.dtype) -> int:
dtype_map = {
np.dtype("float16"): 0,
np.dtype("float32"): 1,
np.dtype("float64"): 2,
np.dtype("int8"): 4,
np.dtype("int16"): 5,
np.dtype("int32"): 6,
np.dtype("int64"): 7,
np.dtype("uint8"): 8,
np.dtype("uint16"): 9,
np.dtype("uint32"): 10,
np.dtype("uint64"): 11,
}
if platform.system() != "Windows":
dtype_map.update({np.dtype("float128"): 3})
if dtype not in dtype_map:
raise TypeError(f"data type {dtype} is not supported on the current platform.")
return dtype_map[dtype]
@unique
@@ -229,24 +226,23 @@ def allreduce(data: np.ndarray, op: Op) -> np.ndarray: # pylint:disable=invalid
"""
if not isinstance(data, np.ndarray):
raise TypeError("allreduce only takes in numpy.ndarray")
buf = data.ravel()
if buf.base is data.base:
buf = buf.copy()
if buf.dtype not in DTYPE_ENUM__:
raise TypeError(f"data type {buf.dtype} not supported")
buf = data.ravel().copy()
_check_call(
_LIB.XGCommunicatorAllreduce(
buf.ctypes.data_as(ctypes.c_void_p),
buf.size,
DTYPE_ENUM__[buf.dtype],
_map_dtype(buf.dtype),
int(op),
None,
None,
)
)
return buf
def signal_error() -> None:
    """Signal an error to the collective communicator; this kills the process."""
    status = _LIB.XGCommunicatorSignalError()
    _check_call(status)
class CommunicatorContext:
"""A context controlling collective communicator initialization and finalization."""

View File

@@ -295,7 +295,7 @@ def _check_distributed_params(kwargs: Dict[str, Any]) -> None:
if device and device.find(":") != -1:
raise ValueError(
"Distributed training doesn't support selecting device ordinal as GPUs are"
" managed by the distributed framework. use `device=cuda` or `device=gpu`"
" managed by the distributed frameworks. use `device=cuda` or `device=gpu`"
" instead."
)

View File

@@ -71,6 +71,7 @@ from xgboost.core import (
Metric,
Objective,
QuantileDMatrix,
XGBoostError,
_check_distributed_params,
_deprecate_positional_args,
_expect,
@@ -90,7 +91,7 @@ from xgboost.sklearn import (
_wrap_evaluation_matrices,
xgboost_model_doc,
)
from xgboost.tracker import RabitTracker, get_host_ip
from xgboost.tracker import RabitTracker
from xgboost.training import train as worker_train
from .utils import get_n_threads
@@ -160,36 +161,38 @@ def _try_start_tracker(
n_workers: int,
addrs: List[Union[Optional[str], Optional[Tuple[str, int]]]],
) -> Dict[str, Union[int, str]]:
env: Dict[str, Union[int, str]] = {"DMLC_NUM_WORKER": n_workers}
env: Dict[str, Union[int, str]] = {}
try:
if isinstance(addrs[0], tuple):
host_ip = addrs[0][0]
port = addrs[0][1]
rabit_tracker = RabitTracker(
host_ip=get_host_ip(host_ip),
n_workers=n_workers,
host_ip=host_ip,
port=port,
use_logger=False,
sortby="task",
)
else:
addr = addrs[0]
assert isinstance(addr, str) or addr is None
host_ip = get_host_ip(addr)
rabit_tracker = RabitTracker(
host_ip=host_ip, n_workers=n_workers, use_logger=False, sortby="task"
n_workers=n_workers, host_ip=addr, sortby="task"
)
env.update(rabit_tracker.worker_envs())
rabit_tracker.start(n_workers)
thread = Thread(target=rabit_tracker.join)
rabit_tracker.start()
thread = Thread(target=rabit_tracker.wait_for)
thread.daemon = True
thread.start()
except socket.error as e:
if len(addrs) < 2 or e.errno != 99:
env.update(rabit_tracker.worker_args())
except XGBoostError as e:
if len(addrs) < 2:
raise
LOGGER.warning(
"Failed to bind address '%s', trying to use '%s' instead.",
"Failed to bind address '%s', trying to use '%s' instead. Error:\n %s",
str(addrs[0]),
str(addrs[1]),
str(e),
)
env = _try_start_tracker(n_workers, addrs[1:])

View File

@@ -1,45 +1,85 @@
"""XGBoost Federated Learning related API."""
"""XGBoost Experimental Federated Learning related API."""
from .core import _LIB, XGBoostError, _check_call, build_info, c_str
import ctypes
from threading import Thread
from typing import Any, Dict, Optional
from .core import _LIB, _check_call, make_jcargs
from .tracker import RabitTracker
def run_federated_server(
port: int,
world_size: int,
server_key_path: str = "",
server_cert_path: str = "",
client_cert_path: str = "",
) -> None:
"""Run the Federated Learning server.
class FederatedTracker(RabitTracker):
"""Tracker for federated training.
Parameters
----------
port : int
The port to listen on.
world_size: int
n_workers :
The number of federated workers.
server_key_path: str
Path to the server private key file. SSL is turned off if empty.
server_cert_path: str
Path to the server certificate file. SSL is turned off if empty.
client_cert_path: str
Path to the client certificate file. SSL is turned off if empty.
port :
The port to listen on.
secure :
Whether this is a secure instance. If True, then the following arguments for SSL
must be provided.
server_key_path :
Path to the server private key file.
server_cert_path :
Path to the server certificate file.
client_cert_path :
Path to the client certificate file.
"""
if build_info()["USE_FEDERATED"]:
if not server_key_path or not server_cert_path or not client_cert_path:
_check_call(_LIB.XGBRunInsecureFederatedServer(port, world_size))
else:
_check_call(
_LIB.XGBRunFederatedServer(
port,
world_size,
c_str(server_key_path),
c_str(server_cert_path),
c_str(client_cert_path),
)
)
else:
raise XGBoostError(
"XGBoost needs to be built with the federated learning plugin "
"enabled in order to use this module"
def __init__(  # pylint: disable=R0913, W0231
    self,
    n_workers: int,
    port: int,
    secure: bool,
    server_key_path: str = "",
    server_cert_path: str = "",
    client_cert_path: str = "",
    timeout: int = 300,
) -> None:
    """Create the native federated tracker handle.

    See the class docstring for the parameter descriptions.
    """
    # W0231: the parent ``__init__`` is intentionally skipped; this class
    # creates its own native handle with a federated configuration instead.
    handle = ctypes.c_void_p()
    # Pack the configuration for the native tracker via ``make_jcargs``.
    args = make_jcargs(
        n_workers=n_workers,
        port=port,
        dmlc_communicator="federated",
        federated_secure=secure,
        server_key_path=server_key_path,
        server_cert_path=server_cert_path,
        client_cert_path=client_cert_path,
        timeout=int(timeout),
    )
    _check_call(_LIB.XGTrackerCreate(args, ctypes.byref(handle)))
    self.handle = handle
def run_federated_server(  # pylint: disable=too-many-arguments
    n_workers: int,
    port: int,
    server_key_path: Optional[str] = None,
    server_cert_path: Optional[str] = None,
    client_cert_path: Optional[str] = None,
    timeout: int = 300,
) -> Dict[str, Any]:
    """Run a federated tracker and return the arguments for workers.

    SSL is enabled only when all three of `server_key_path`,
    `server_cert_path`, and `client_cert_path` are provided.  See
    :py:class:`~xgboost.federated.FederatedTracker` for more info.
    """
    args: Dict[str, Any] = {"n_workers": n_workers}
    secure = all(
        path is not None
        for path in [server_key_path, server_cert_path, client_cert_path]
    )
    tracker = FederatedTracker(
        n_workers=n_workers,
        port=port,
        secure=secure,
        # Fix: forward the SSL file paths to the tracker.  Previously they
        # were accepted here but silently dropped, so a "secure" tracker was
        # created without its key/certificates.
        server_key_path=server_key_path or "",
        server_cert_path=server_cert_path or "",
        client_cert_path=client_cert_path or "",
        timeout=timeout,
    )
    tracker.start()
    # Run the blocking wait in a daemon thread so it won't prevent shutdown.
    thread = Thread(target=tracker.wait_for)
    thread.daemon = True
    thread.start()
    args.update(tracker.worker_args())
    return args

View File

@@ -47,21 +47,21 @@ class CommunicatorContext(CCtx): # pylint: disable=too-few-public-methods
"""Context with PySpark specific task ID."""
def __init__(self, context: BarrierTaskContext, **args: Any) -> None:
args["DMLC_TASK_ID"] = str(context.partitionId())
args["dmlc_task_id"] = str(context.partitionId())
super().__init__(**args)
def _start_tracker(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]:
"""Start Rabit tracker with n_workers"""
env: Dict[str, Any] = {"DMLC_NUM_WORKER": n_workers}
args: Dict[str, Any] = {"n_workers": n_workers}
host = _get_host_ip(context)
rabit_context = RabitTracker(host_ip=host, n_workers=n_workers, sortby="task")
env.update(rabit_context.worker_envs())
rabit_context.start(n_workers)
thread = Thread(target=rabit_context.join)
tracker = RabitTracker(n_workers=n_workers, host_ip=host, sortby="task")
tracker.start()
thread = Thread(target=tracker.wait_for)
thread.daemon = True
thread.start()
return env
args.update(tracker.worker_args())
return args
def _get_rabit_args(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]:

View File

@@ -111,8 +111,6 @@ def no_sklearn() -> PytestSkip:
def no_dask() -> PytestSkip:
    """Skip-marker payload: skip when dask is unavailable or on Windows."""
    if sys.platform.startswith("win"):
        # Dask-based tests are unconditionally skipped on Windows.
        return {"reason": "Unsupported platform.", "condition": True}
    return no_mod("dask")
@@ -193,6 +191,10 @@ def no_multiple(*args: Any) -> PytestSkip:
return {"condition": condition, "reason": reason}
def skip_win() -> PytestSkip:
    """Skip-marker payload for tests that cannot run on Windows."""
    return {"reason": "Unsupported platform.", "condition": is_windows()}
def skip_s390x() -> PytestSkip:
condition = platform.machine() == "s390x"
reason = "Known to fail on s390x"
@@ -968,18 +970,18 @@ def run_with_rabit(
exception_queue.put(e)
tracker = RabitTracker(host_ip="127.0.0.1", n_workers=world_size)
tracker.start(world_size)
tracker.start()
workers = []
for _ in range(world_size):
worker = threading.Thread(target=run_worker, args=(tracker.worker_envs(),))
worker = threading.Thread(target=run_worker, args=(tracker.worker_args(),))
workers.append(worker)
worker.start()
for worker in workers:
worker.join()
assert exception_queue.empty(), f"Worker failed: {exception_queue.get()}"
tracker.join()
tracker.wait_for()
def column_split_feature_names(

View File

@@ -1,64 +1,12 @@
# pylint: disable=too-many-instance-attributes, too-many-arguments, too-many-branches
"""
This script is a variant of dmlc-core/dmlc_tracker/tracker.py,
which is a specialized version for xgboost tasks.
"""
import argparse
import logging
"""Tracker for XGBoost collective."""
import ctypes
import json
import socket
import struct
import sys
from threading import Thread
from typing import Dict, List, Optional, Set, Tuple, Union
from enum import IntEnum, unique
from typing import Dict, Optional, Union
_RingMap = Dict[int, Tuple[int, int]]
_TreeMap = Dict[int, List[int]]
class ExSocket:
    """Socket wrapper adding framed send/receive of integers and strings.

    Integers travel as 4-byte native-endian values; strings are sent as a
    4-byte byte-length prefix followed by the UTF-8 payload.
    """

    def __init__(self, sock: socket.socket) -> None:
        self.sock = sock

    def recvall(self, nbytes: int) -> bytes:
        """Receive exactly `nbytes` bytes.

        Raises
        ------
        ConnectionResetError
            If the peer closes the connection before all bytes arrive.  The
            previous implementation looped forever in this case, because a
            closed socket makes ``recv`` return ``b""`` repeatedly.
        """
        res = []
        nread = 0
        while nread < nbytes:
            chunk = self.sock.recv(min(nbytes - nread, 1024))
            if not chunk:
                raise ConnectionResetError(
                    f"Connection closed after receiving {nread} of {nbytes} bytes."
                )
            nread += len(chunk)
            res.append(chunk)
        return b"".join(res)

    def recvint(self) -> int:
        """Receive a 4-byte native-endian integer."""
        return struct.unpack("@i", self.recvall(4))[0]

    def sendint(self, value: int) -> None:
        """Send a 4-byte native-endian integer."""
        self.sock.sendall(struct.pack("@i", value))

    def sendstr(self, value: str) -> None:
        """Send a length-prefixed string.

        The prefix is the encoded *byte* length.  The old code sent the
        character count (``len(value)``) while transmitting the UTF-8 bytes,
        which corrupted the framing for any non-ASCII string.
        """
        payload = value.encode()
        self.sendint(len(payload))
        self.sock.sendall(payload)

    def recvstr(self) -> str:
        """Receive a length-prefixed string."""
        slen = self.recvint()
        return self.recvall(slen).decode()
# magic number used to verify existence of data
MAGIC_NUM = 0xFF99
def get_some_ip(host: str) -> str:
"""Get ip from host"""
return socket.getaddrinfo(host, None)[0][4][0]
from .core import _LIB, _check_call, make_jcargs
def get_family(addr: str) -> int:
@@ -66,439 +14,95 @@ def get_family(addr: str) -> int:
return socket.getaddrinfo(addr, None)[0][0]
class WorkerEntry:
    """Handler for a single connected worker.

    The constructor performs the initial handshake: it validates the magic
    number, echoes it back, then reads the worker's rank, world size, task
    id, and command.
    """

    def __init__(self, sock: socket.socket, s_addr: Tuple[str, int]):
        worker = ExSocket(sock)
        self.sock = worker
        self.host = get_some_ip(s_addr[0])
        # Handshake: both sides exchange MAGIC_NUM to verify the protocol.
        magic = worker.recvint()
        assert magic == MAGIC_NUM, f"invalid magic number={magic} from {self.host}"
        worker.sendint(MAGIC_NUM)
        # rank -1 means the worker has no pre-assigned rank yet.
        self.rank = worker.recvint()
        self.world_size = worker.recvint()
        self.task_id = worker.recvstr()
        # Command handled by the tracker loop: "print", "shutdown", "start",
        # or "recover".
        self.cmd = worker.recvstr()
        # Number of peer connections this worker is still waiting to accept.
        self.wait_accept = 0
        # Listening port, learned from the worker in _get_remote.
        self.port: Optional[int] = None

    def print(self, use_logger: bool) -> None:
        """Execute the print command from worker."""
        msg = self.sock.recvstr()
        # On dask we use print to avoid setting global verbosity.
        if use_logger:
            logging.info(msg.strip())
        else:
            print(msg.strip(), flush=True)

    def decide_rank(self, job_map: Dict[str, int]) -> int:
        """Get the rank of current entry.

        Returns the worker's own rank if it already has one, the rank
        previously recorded for its task id, or -1 when a fresh rank must be
        assigned by the tracker.
        """
        if self.rank >= 0:
            return self.rank
        if self.task_id != "NULL" and self.task_id in job_map:
            return job_map[self.task_id]
        return -1

    def assign_rank(
        self,
        rank: int,
        wait_conn: Dict[int, "WorkerEntry"],
        tree_map: _TreeMap,
        parent_map: Dict[int, int],
        ring_map: _RingMap,
    ) -> List[int]:
        """Assign the rank for current entry and send it its topology: parent
        rank, world size, tree neighbours, and the ring prev/next links.
        """
        self.rank = rank
        nnset = set(tree_map[rank])
        rprev, next_rank = ring_map[rank]
        self.sock.sendint(rank)
        # send parent rank
        self.sock.sendint(parent_map[rank])
        # send world size
        self.sock.sendint(len(tree_map))
        self.sock.sendint(len(nnset))
        # send the rprev and next link
        for r in nnset:
            self.sock.sendint(r)
        # send prev link (-1 when absent or self)
        if rprev not in (-1, rank):
            nnset.add(rprev)
            self.sock.sendint(rprev)
        else:
            self.sock.sendint(-1)
        # send next link (-1 when absent or self)
        if next_rank not in (-1, rank):
            nnset.add(next_rank)
            self.sock.sendint(next_rank)
        else:
            self.sock.sendint(-1)
        return self._get_remote(wait_conn, nnset)

    def _get_remote(
        self, wait_conn: Dict[int, "WorkerEntry"], badset: Set[int]
    ) -> List[int]:
        """Tell the worker which of its neighbours are already waiting for
        connections, and return the ranks whose wait counters reached zero.
        """
        while True:
            # Neighbours already connected and waiting for this worker.
            conset = []
            for r in badset:
                if r in wait_conn:
                    conset.append(r)
            self.sock.sendint(len(conset))
            self.sock.sendint(len(badset) - len(conset))
            for r in conset:
                self.sock.sendstr(wait_conn[r].host)
                port = wait_conn[r].port
                assert port is not None
                # send port of this node to other workers so that they can call connect
                self.sock.sendint(port)
                self.sock.sendint(r)
            # Non-zero error count from the worker: retry the whole exchange.
            nerr = self.sock.recvint()
            if nerr != 0:
                continue
            # Worker reports the port it is listening on for its peers.
            self.port = self.sock.recvint()
            rmset = []
            # All connections were successfully set up; decrement the peers'
            # wait counters and collect those that are fully connected.
            for r in conset:
                wait_conn[r].wait_accept -= 1
                if wait_conn[r].wait_accept == 0:
                    rmset.append(r)
            for r in rmset:
                wait_conn.pop(r, None)
            self.wait_accept = len(badset) - len(conset)
            return rmset
class RabitTracker:
    """Pure-Python RABIT tracker: accepts worker connections, assigns ranks,
    and distributes the tree/ring topology used for collective operations.
    """

    def __init__(
        self,
        host_ip: str,
        n_workers: int,
        port: int = 0,
        use_logger: bool = False,
        sortby: str = "host",
    ) -> None:
        """A Python implementation of RABIT tracker.

        Parameters
        ----------
        use_logger:
            Use logging.info for tracker print command. When set to False, Python print
            function is used instead.
        sortby:
            How to sort the workers for rank assignment. The default is host, but users
            can set the `DMLC_TASK_ID` via RABIT initialization arguments and obtain
            deterministic rank assignment. Available options are:
              - host
              - task
        """
        sock = socket.socket(get_family(host_ip), socket.SOCK_STREAM)
        sock.bind((host_ip, port))
        # port=0 lets the OS choose a free port; record the actual one.
        self.port = sock.getsockname()[1]
        sock.listen(256)
        self.sock = sock
        self.host_ip = host_ip
        self.thread: Optional[Thread] = None
        self.n_workers = n_workers
        self._use_logger = use_logger
        self._sortby = sortby
        logging.info("start listen on %s:%d", host_ip, self.port)

    def __del__(self) -> None:
        # ``sock`` may not exist if ``__init__`` failed before binding.
        if hasattr(self, "sock"):
            self.sock.close()

    @staticmethod
    def _get_neighbor(rank: int, n_workers: int) -> List[int]:
        """Neighbours (parent and children) of `rank` in a complete binary
        tree, computed with 1-based heap arithmetic and returned 0-based.
        """
        rank = rank + 1
        ret = []
        if rank > 1:
            ret.append(rank // 2 - 1)
        if rank * 2 - 1 < n_workers:
            ret.append(rank * 2 - 1)
        if rank * 2 < n_workers:
            ret.append(rank * 2)
        return ret

    def worker_envs(self) -> Dict[str, Union[str, int]]:
        """
        get environment variables for workers
        can be passed in as args or envs
        """
        return {"DMLC_TRACKER_URI": self.host_ip, "DMLC_TRACKER_PORT": self.port}

    def _get_tree(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int]]:
        # Build the binary-tree adjacency and parent maps (root's parent = -1).
        tree_map: _TreeMap = {}
        parent_map: Dict[int, int] = {}
        for r in range(n_workers):
            tree_map[r] = self._get_neighbor(r, n_workers)
            parent_map[r] = (r + 1) // 2 - 1
        return tree_map, parent_map

    def find_share_ring(
        self, tree_map: _TreeMap, parent_map: Dict[int, int], rank: int
    ) -> List[int]:
        """
        get a ring structure that tends to share nodes with the tree
        return a list starting from rank

        Performs a depth-first walk over the children of `rank`; the last
        child's sub-list is reversed so the walk ends adjacent to `rank`.
        """
        nset = set(tree_map[rank])
        # Children only: the parent is excluded from the recursion.
        cset = nset - {parent_map[rank]}
        if not cset:
            return [rank]
        rlst = [rank]
        cnt = 0
        for v in cset:
            vlst = self.find_share_ring(tree_map, parent_map, v)
            cnt += 1
            if cnt == len(cset):
                vlst.reverse()
            rlst += vlst
        return rlst

    def get_ring(self, tree_map: _TreeMap, parent_map: Dict[int, int]) -> _RingMap:
        """
        get a ring connection used to recover local data
        """
        assert parent_map[0] == -1
        rlst = self.find_share_ring(tree_map, parent_map, 0)
        assert len(rlst) == len(tree_map)
        ring_map: _RingMap = {}
        n_workers = len(tree_map)
        # Close the DFS ordering into a cycle: each rank gets (prev, next).
        for r in range(n_workers):
            rprev = (r + n_workers - 1) % n_workers
            rnext = (r + 1) % n_workers
            ring_map[rlst[r]] = (rlst[rprev], rlst[rnext])
        return ring_map

    def get_link_map(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int], _RingMap]:
        """
        get the link map, this is a bit hacky, call for better algorithm
        to place similar nodes together
        """
        tree_map, parent_map = self._get_tree(n_workers)
        ring_map = self.get_ring(tree_map, parent_map)
        # Relabel ranks so that walking the ring from 0 yields consecutive
        # ranks 0, 1, 2, ...
        rmap = {0: 0}
        k = 0
        for i in range(n_workers - 1):
            k = ring_map[k][1]
            rmap[k] = i + 1
        ring_map_: _RingMap = {}
        tree_map_: _TreeMap = {}
        parent_map_: Dict[int, int] = {}
        # Rewrite all three maps under the new labels.
        for k, v in ring_map.items():
            ring_map_[rmap[k]] = (rmap[v[0]], rmap[v[1]])
        for k, tree_nodes in tree_map.items():
            tree_map_[rmap[k]] = [rmap[x] for x in tree_nodes]
        for k, parent in parent_map.items():
            if k != 0:
                parent_map_[rmap[k]] = rmap[parent]
            else:
                parent_map_[rmap[k]] = -1
        return tree_map_, parent_map_, ring_map_

    def _sort_pending(self, pending: List[WorkerEntry]) -> List[WorkerEntry]:
        # Sort order determines which worker receives which rank.
        if self._sortby == "host":
            pending.sort(key=lambda s: s.host)
        elif self._sortby == "task":
            pending.sort(key=lambda s: s.task_id)
        return pending

    def accept_workers(self, n_workers: int) -> None:
        """Wait for all workers to connect to the tracker."""
        # set of nodes that finishes the job
        shutdown: Dict[int, WorkerEntry] = {}
        # set of nodes that is waiting for connections
        wait_conn: Dict[int, WorkerEntry] = {}
        # maps job id to rank
        job_map: Dict[str, int] = {}
        # list of workers that is pending to be assigned rank
        pending: List[WorkerEntry] = []
        # lazy initialize tree_map
        tree_map = None

        while len(shutdown) != n_workers:
            fd, s_addr = self.sock.accept()
            # Each accept performs the handshake and reads the command.
            s = WorkerEntry(fd, s_addr)
            if s.cmd == "print":
                s.print(self._use_logger)
                continue
            if s.cmd == "shutdown":
                assert s.rank >= 0 and s.rank not in shutdown
                assert s.rank not in wait_conn
                shutdown[s.rank] = s
                logging.debug("Received %s signal from %d", s.cmd, s.rank)
                continue
            assert s.cmd == "start"
            # lazily initialize the workers
            if tree_map is None:
                assert s.cmd == "start"
                if s.world_size > 0:
                    n_workers = s.world_size
                tree_map, parent_map, ring_map = self.get_link_map(n_workers)
                # set of nodes that is pending for getting up
                todo_nodes = list(range(n_workers))
            else:
                assert s.world_size in (-1, n_workers)
                # NOTE(review): this branch can never fire given the
                # ``assert s.cmd == "start"`` above — looks like part of the
                # condition may have been lost; confirm against upstream.
                if s.cmd == "recover":
                    assert s.rank >= 0
            rank = s.decide_rank(job_map)
            # batch assignment of ranks
            if rank == -1:
                assert todo_nodes
                pending.append(s)
                # Only assign ranks once every expected worker has connected,
                # so the sort below sees the full set.
                if len(pending) == len(todo_nodes):
                    pending = self._sort_pending(pending)
                    for s in pending:
                        rank = todo_nodes.pop(0)
                        if s.task_id != "NULL":
                            job_map[s.task_id] = rank
                        s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
                        if s.wait_accept > 0:
                            wait_conn[rank] = s
                        logging.debug(
                            "Received %s signal from %s; assign rank %d",
                            s.cmd,
                            s.host,
                            s.rank,
                        )
                if not todo_nodes:
                    logging.info("@tracker All of %d nodes getting started", n_workers)
            else:
                # Worker already knows its rank (e.g. reconnecting).
                s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
                logging.debug("Received %s signal from %d", s.cmd, s.rank)
                if s.wait_accept > 0:
                    wait_conn[rank] = s
        logging.info("@tracker All nodes finishes job")

    def start(self, n_workers: int) -> None:
        """Start the tracker in a daemon thread; it will wait for `n_workers`
        to connect."""

        def run() -> None:
            self.accept_workers(n_workers)

        self.thread = Thread(target=run, args=(), daemon=True)
        self.thread.start()

    def join(self) -> None:
        """Wait for the tracker to finish."""
        # Join with a timeout in a loop so the main thread stays interruptible.
        while self.thread is not None and self.thread.is_alive():
            self.thread.join(100)

    def alive(self) -> bool:
        """Whether the tracker thread is alive"""
        return self.thread is not None and self.thread.is_alive()
def get_host_ip(host_ip: Optional[str] = None) -> str:
    """Get the IP address of current host.

    If `host_ip` is given and is not one of the special values ``None``,
    ``"auto"``, ``"dns"``, or ``"ip"``, it is returned unchanged.
    """
    if host_ip is None or host_ip == "auto":
        host_ip = "ip"

    if host_ip == "dns":
        host_ip = socket.getfqdn()
    elif host_ip == "ip":
        from socket import gaierror

        try:
            host_ip = socket.gethostbyname(socket.getfqdn())
        except gaierror:
            logging.debug(
                "gethostbyname(socket.getfqdn()) failed... trying on hostname()"
            )
            host_ip = socket.gethostbyname(socket.gethostname())
        if host_ip.startswith("127."):
            # Fix: close the probe socket when done (it previously leaked).
            # Connecting a UDP socket sends no packets; it only selects the
            # outgoing interface so we can read a routable local address.
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
                # doesn't have to be reachable
                s.connect(("10.255.255.255", 1))
                host_ip = s.getsockname()[0]

    assert host_ip is not None
    return host_ip
def start_rabit_tracker(args: argparse.Namespace) -> None:
"""Standalone function to start rabit tracker.
"""Tracker for the collective used in XGBoost, acting as a coordinator between
workers.
Parameters
----------
args: arguments to start the rabit tracker.
..........
sortby:
How to sort the workers for rank assignment. The default is host, but users can
set the `DMLC_TASK_ID` via RABIT initialization arguments and obtain
deterministic rank assignment. Available options are:
- host
- task
timeout :
Timeout for constructing the communication group and waiting for the tracker to
shutdown when it's instructed to, doesn't apply to communication when tracking
is running.
The timeout value should take the time of data loading and pre-processing into
account, due to potential lazy execution.
The :py:meth:`.wait_for` method has a different timeout parameter that can stop
the tracker even if the tracker is still being used. A value error is raised
when timeout is reached.
"""
envs = {"DMLC_NUM_WORKER": args.num_workers, "DMLC_NUM_SERVER": args.num_servers}
rabit = RabitTracker(
host_ip=get_host_ip(args.host_ip), n_workers=args.num_workers, use_logger=True
)
envs.update(rabit.worker_envs())
rabit.start(args.num_workers)
sys.stdout.write("DMLC_TRACKER_ENV_START\n")
# simply write configuration to stdout
for k, v in envs.items():
sys.stdout.write(f"{k}={v}\n")
sys.stdout.write("DMLC_TRACKER_ENV_END\n")
sys.stdout.flush()
rabit.join()
@unique
class _SortBy(IntEnum):
    """How pending workers are sorted for rank assignment; the integer value
    is passed through to the native tracker."""

    HOST = 0
    TASK = 1
def main() -> None:
"""Main function if tracker is executed in standalone mode."""
parser = argparse.ArgumentParser(description="Rabit Tracker start.")
parser.add_argument(
"--num-workers",
required=True,
type=int,
help="Number of worker process to be launched.",
)
parser.add_argument(
"--num-servers",
default=0,
type=int,
help="Number of server process to be launched. Only used in PS jobs.",
)
parser.add_argument(
"--host-ip",
default=None,
type=str,
help=(
"Host IP addressed, this is only needed "
+ "if the host IP cannot be automatically guessed."
),
)
parser.add_argument(
"--log-level",
default="INFO",
type=str,
choices=["INFO", "DEBUG"],
help="Logging level of the logger.",
)
args = parser.parse_args()
def __init__( # pylint: disable=too-many-arguments
self,
n_workers: int,
host_ip: Optional[str],
port: int = 0,
sortby: str = "host",
timeout: int = 0,
) -> None:
fmt = "%(asctime)s %(levelname)s %(message)s"
if args.log_level == "INFO":
level = logging.INFO
elif args.log_level == "DEBUG":
level = logging.DEBUG
else:
raise RuntimeError(f"Unknown logging level {args.log_level}")
handle = ctypes.c_void_p()
if sortby not in ("host", "task"):
raise ValueError("Expecting either 'host' or 'task' for sortby.")
if host_ip is not None:
get_family(host_ip) # use python socket to stop early for invalid address
args = make_jcargs(
host=host_ip,
n_workers=n_workers,
port=port,
dmlc_communicator="rabit",
sortby=self._SortBy.HOST if sortby == "host" else self._SortBy.TASK,
timeout=int(timeout),
)
_check_call(_LIB.XGTrackerCreate(args, ctypes.byref(handle)))
self.handle = handle
logging.basicConfig(format=fmt, level=level)
def free(self) -> None:
    """Release the native tracker handle. Internal function for testing."""
    if not hasattr(self, "handle"):
        return
    handle = self.handle
    # Drop the attribute before freeing so a re-entrant call is a no-op.
    del self.handle
    _check_call(_LIB.XGTrackerFree(handle))
if args.num_servers == 0:
start_rabit_tracker(args)
else:
raise RuntimeError("Do not yet support start ps tracker in standalone mode.")
def __del__(self) -> None:
self.free()
def start(self) -> None:
"""Start the tracker. Once started, the client still need to call the
:py:meth:`wait_for` method in order to wait for it to finish (think of it as a
thread).
if __name__ == "__main__":
main()
"""
_check_call(_LIB.XGTrackerRun(self.handle, make_jcargs()))
def wait_for(self, timeout: Optional[int] = None) -> None:
    """Block until the tracker finishes all its work and shuts down.  A value
    error is raised when the timeout is reached.  By default there is no
    timeout, since we don't know how long model training will take.
    """
    config = make_jcargs(timeout=timeout)
    _check_call(_LIB.XGTrackerWaitFor(self.handle, config))
def worker_args(self) -> Dict[str, Union[str, int]]:
    """Get arguments for workers."""
    payload = ctypes.c_char_p()
    _check_call(_LIB.XGTrackerWorkerArgs(self.handle, ctypes.byref(payload)))
    raw = payload.value
    assert raw is not None
    return json.loads(raw)