Revamp the rabit implementation. (#10112)
This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features: - Federated learning for both CPU and GPU. - NCCL. - More data types. - A unified interface for all the underlying implementations. - Improved timeout handling for both tracker and workers. - Exhausted tests with metrics (fixed a couple of bugs along the way). - A reusable tracker for Python and JVM packages.
This commit is contained in:
@@ -1,17 +1,17 @@
|
||||
"""XGBoost collective communication related API."""
|
||||
|
||||
import ctypes
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import platform
|
||||
from enum import IntEnum, unique
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ._typing import _T
|
||||
from .core import _LIB, _check_call, build_info, c_str, from_pystr_to_cstr, py_str
|
||||
from .core import _LIB, _check_call, build_info, c_str, make_jcargs, py_str
|
||||
|
||||
LOGGER = logging.getLogger("[xgboost.collective]")
|
||||
|
||||
@@ -21,49 +21,35 @@ def init(**args: Any) -> None:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
args: Dict[str, Any]
|
||||
args :
|
||||
Keyword arguments representing the parameters and their values.
|
||||
|
||||
Accepted parameters:
|
||||
- xgboost_communicator: The type of the communicator. Can be set as an environment
|
||||
variable.
|
||||
- dmlc_communicator: The type of the communicator.
|
||||
* rabit: Use Rabit. This is the default if the type is unspecified.
|
||||
* federated: Use the gRPC interface for Federated Learning.
|
||||
Only applicable to the Rabit communicator (these are case sensitive):
|
||||
-- rabit_tracker_uri: Hostname of the tracker.
|
||||
-- rabit_tracker_port: Port number of the tracker.
|
||||
-- rabit_task_id: ID of the current task, can be used to obtain deterministic rank
|
||||
assignment.
|
||||
-- rabit_world_size: Total number of workers.
|
||||
-- rabit_hadoop_mode: Enable Hadoop support.
|
||||
-- rabit_tree_reduce_minsize: Minimal size for tree reduce.
|
||||
-- rabit_reduce_ring_mincount: Minimal count to perform ring reduce.
|
||||
-- rabit_reduce_buffer: Size of the reduce buffer.
|
||||
-- rabit_bootstrap_cache: Size of the bootstrap cache.
|
||||
-- rabit_debug: Enable debugging.
|
||||
-- rabit_timeout: Enable timeout.
|
||||
-- rabit_timeout_sec: Timeout in seconds.
|
||||
-- rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms.
|
||||
Only applicable to the Rabit communicator (these are case-sensitive, and can be set as
|
||||
environment variables):
|
||||
-- DMLC_TRACKER_URI: Hostname of the tracker.
|
||||
-- DMLC_TRACKER_PORT: Port number of the tracker.
|
||||
-- DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank
|
||||
assignment.
|
||||
-- DMLC_ROLE: Role of the current task, "worker" or "server".
|
||||
-- DMLC_NUM_ATTEMPT: Number of attempts after task failure.
|
||||
-- DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
|
||||
Only applicable to the Federated communicator (use upper case for environment variables, use
|
||||
lower case for runtime configuration):
|
||||
-- federated_server_address: Address of the federated server.
|
||||
-- federated_world_size: Number of federated workers.
|
||||
-- federated_rank: Rank of the current worker.
|
||||
-- federated_server_cert: Server certificate file path. Only needed for the SSL mode.
|
||||
-- federated_client_key: Client key file path. Only needed for the SSL mode.
|
||||
-- federated_client_cert: Client certificate file path. Only needed for the SSL mode.
|
||||
|
||||
Only applicable to the Rabit communicator:
|
||||
- dmlc_tracker_uri: Hostname of the tracker.
|
||||
- dmlc_tracker_port: Port number of the tracker.
|
||||
- dmlc_task_id: ID of the current task, can be used to obtain deterministic
|
||||
- dmlc_retry: The number of retry when handling network errors.
|
||||
- dmlc_timeout: Timeout in seconds.
|
||||
- dmlc_nccl_path: Path to load (dlopen) nccl for GPU-based communication.
|
||||
|
||||
Only applicable to the Federated communicator (use upper case for environment
|
||||
variables, use lower case for runtime configuration):
|
||||
|
||||
- federated_server_address: Address of the federated server.
|
||||
- federated_world_size: Number of federated workers.
|
||||
- federated_rank: Rank of the current worker.
|
||||
- federated_server_cert: Server certificate file path. Only needed for the SSL
|
||||
mode.
|
||||
- federated_client_key: Client key file path. Only needed for the SSL mode.
|
||||
- federated_client_cert: Client certificate file path. Only needed for the SSL
|
||||
mode.
|
||||
"""
|
||||
config = from_pystr_to_cstr(json.dumps(args))
|
||||
_check_call(_LIB.XGCommunicatorInit(config))
|
||||
_check_call(_LIB.XGCommunicatorInit(make_jcargs(**args)))
|
||||
|
||||
|
||||
def finalize() -> None:
|
||||
@@ -157,7 +143,7 @@ def broadcast(data: _T, root: int) -> _T:
|
||||
assert data is not None, "need to pass in data when broadcasting"
|
||||
s = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
length.value = len(s)
|
||||
# run first broadcast
|
||||
# Run first broadcast
|
||||
_check_call(
|
||||
_LIB.XGCommunicatorBroadcast(
|
||||
ctypes.byref(length), ctypes.sizeof(ctypes.c_ulong), root
|
||||
@@ -184,16 +170,27 @@ def broadcast(data: _T, root: int) -> _T:
|
||||
|
||||
|
||||
# enumeration of dtypes
|
||||
DTYPE_ENUM__ = {
|
||||
np.dtype("int8"): 0,
|
||||
np.dtype("uint8"): 1,
|
||||
np.dtype("int32"): 2,
|
||||
np.dtype("uint32"): 3,
|
||||
np.dtype("int64"): 4,
|
||||
np.dtype("uint64"): 5,
|
||||
np.dtype("float32"): 6,
|
||||
np.dtype("float64"): 7,
|
||||
}
|
||||
def _map_dtype(dtype: np.dtype) -> int:
|
||||
dtype_map = {
|
||||
np.dtype("float16"): 0,
|
||||
np.dtype("float32"): 1,
|
||||
np.dtype("float64"): 2,
|
||||
np.dtype("int8"): 4,
|
||||
np.dtype("int16"): 5,
|
||||
np.dtype("int32"): 6,
|
||||
np.dtype("int64"): 7,
|
||||
np.dtype("uint8"): 8,
|
||||
np.dtype("uint16"): 9,
|
||||
np.dtype("uint32"): 10,
|
||||
np.dtype("uint64"): 11,
|
||||
}
|
||||
if platform.system() != "Windows":
|
||||
dtype_map.update({np.dtype("float128"): 3})
|
||||
|
||||
if dtype not in dtype_map:
|
||||
raise TypeError(f"data type {dtype} is not supported on the current platform.")
|
||||
|
||||
return dtype_map[dtype]
|
||||
|
||||
|
||||
@unique
|
||||
@@ -229,24 +226,23 @@ def allreduce(data: np.ndarray, op: Op) -> np.ndarray: # pylint:disable=invalid
|
||||
"""
|
||||
if not isinstance(data, np.ndarray):
|
||||
raise TypeError("allreduce only takes in numpy.ndarray")
|
||||
buf = data.ravel()
|
||||
if buf.base is data.base:
|
||||
buf = buf.copy()
|
||||
if buf.dtype not in DTYPE_ENUM__:
|
||||
raise TypeError(f"data type {buf.dtype} not supported")
|
||||
buf = data.ravel().copy()
|
||||
_check_call(
|
||||
_LIB.XGCommunicatorAllreduce(
|
||||
buf.ctypes.data_as(ctypes.c_void_p),
|
||||
buf.size,
|
||||
DTYPE_ENUM__[buf.dtype],
|
||||
_map_dtype(buf.dtype),
|
||||
int(op),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
)
|
||||
return buf
|
||||
|
||||
|
||||
def signal_error() -> None:
|
||||
"""Kill the process."""
|
||||
_check_call(_LIB.XGCommunicatorSignalError())
|
||||
|
||||
|
||||
class CommunicatorContext:
|
||||
"""A context controlling collective communicator initialization and finalization."""
|
||||
|
||||
|
||||
@@ -295,7 +295,7 @@ def _check_distributed_params(kwargs: Dict[str, Any]) -> None:
|
||||
if device and device.find(":") != -1:
|
||||
raise ValueError(
|
||||
"Distributed training doesn't support selecting device ordinal as GPUs are"
|
||||
" managed by the distributed framework. use `device=cuda` or `device=gpu`"
|
||||
" managed by the distributed frameworks. use `device=cuda` or `device=gpu`"
|
||||
" instead."
|
||||
)
|
||||
|
||||
|
||||
@@ -71,6 +71,7 @@ from xgboost.core import (
|
||||
Metric,
|
||||
Objective,
|
||||
QuantileDMatrix,
|
||||
XGBoostError,
|
||||
_check_distributed_params,
|
||||
_deprecate_positional_args,
|
||||
_expect,
|
||||
@@ -90,7 +91,7 @@ from xgboost.sklearn import (
|
||||
_wrap_evaluation_matrices,
|
||||
xgboost_model_doc,
|
||||
)
|
||||
from xgboost.tracker import RabitTracker, get_host_ip
|
||||
from xgboost.tracker import RabitTracker
|
||||
from xgboost.training import train as worker_train
|
||||
|
||||
from .utils import get_n_threads
|
||||
@@ -160,36 +161,38 @@ def _try_start_tracker(
|
||||
n_workers: int,
|
||||
addrs: List[Union[Optional[str], Optional[Tuple[str, int]]]],
|
||||
) -> Dict[str, Union[int, str]]:
|
||||
env: Dict[str, Union[int, str]] = {"DMLC_NUM_WORKER": n_workers}
|
||||
env: Dict[str, Union[int, str]] = {}
|
||||
try:
|
||||
if isinstance(addrs[0], tuple):
|
||||
host_ip = addrs[0][0]
|
||||
port = addrs[0][1]
|
||||
rabit_tracker = RabitTracker(
|
||||
host_ip=get_host_ip(host_ip),
|
||||
n_workers=n_workers,
|
||||
host_ip=host_ip,
|
||||
port=port,
|
||||
use_logger=False,
|
||||
sortby="task",
|
||||
)
|
||||
else:
|
||||
addr = addrs[0]
|
||||
assert isinstance(addr, str) or addr is None
|
||||
host_ip = get_host_ip(addr)
|
||||
rabit_tracker = RabitTracker(
|
||||
host_ip=host_ip, n_workers=n_workers, use_logger=False, sortby="task"
|
||||
n_workers=n_workers, host_ip=addr, sortby="task"
|
||||
)
|
||||
env.update(rabit_tracker.worker_envs())
|
||||
rabit_tracker.start(n_workers)
|
||||
thread = Thread(target=rabit_tracker.join)
|
||||
|
||||
rabit_tracker.start()
|
||||
thread = Thread(target=rabit_tracker.wait_for)
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
except socket.error as e:
|
||||
if len(addrs) < 2 or e.errno != 99:
|
||||
env.update(rabit_tracker.worker_args())
|
||||
|
||||
except XGBoostError as e:
|
||||
if len(addrs) < 2:
|
||||
raise
|
||||
LOGGER.warning(
|
||||
"Failed to bind address '%s', trying to use '%s' instead.",
|
||||
"Failed to bind address '%s', trying to use '%s' instead. Error:\n %s",
|
||||
str(addrs[0]),
|
||||
str(addrs[1]),
|
||||
str(e),
|
||||
)
|
||||
env = _try_start_tracker(n_workers, addrs[1:])
|
||||
|
||||
|
||||
@@ -1,45 +1,85 @@
|
||||
"""XGBoost Federated Learning related API."""
|
||||
"""XGBoost Experimental Federated Learning related API."""
|
||||
|
||||
from .core import _LIB, XGBoostError, _check_call, build_info, c_str
|
||||
import ctypes
|
||||
from threading import Thread
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from .core import _LIB, _check_call, make_jcargs
|
||||
from .tracker import RabitTracker
|
||||
|
||||
|
||||
def run_federated_server(
|
||||
port: int,
|
||||
world_size: int,
|
||||
server_key_path: str = "",
|
||||
server_cert_path: str = "",
|
||||
client_cert_path: str = "",
|
||||
) -> None:
|
||||
"""Run the Federated Learning server.
|
||||
class FederatedTracker(RabitTracker):
|
||||
"""Tracker for federated training.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
port : int
|
||||
The port to listen on.
|
||||
world_size: int
|
||||
n_workers :
|
||||
The number of federated workers.
|
||||
server_key_path: str
|
||||
Path to the server private key file. SSL is turned off if empty.
|
||||
server_cert_path: str
|
||||
Path to the server certificate file. SSL is turned off if empty.
|
||||
client_cert_path: str
|
||||
Path to the client certificate file. SSL is turned off if empty.
|
||||
|
||||
port :
|
||||
The port to listen on.
|
||||
|
||||
secure :
|
||||
Whether this is a secure instance. If True, then the following arguments for SSL
|
||||
must be provided.
|
||||
|
||||
server_key_path :
|
||||
Path to the server private key file.
|
||||
|
||||
server_cert_path :
|
||||
Path to the server certificate file.
|
||||
|
||||
client_cert_path :
|
||||
Path to the client certificate file.
|
||||
|
||||
"""
|
||||
if build_info()["USE_FEDERATED"]:
|
||||
if not server_key_path or not server_cert_path or not client_cert_path:
|
||||
_check_call(_LIB.XGBRunInsecureFederatedServer(port, world_size))
|
||||
else:
|
||||
_check_call(
|
||||
_LIB.XGBRunFederatedServer(
|
||||
port,
|
||||
world_size,
|
||||
c_str(server_key_path),
|
||||
c_str(server_cert_path),
|
||||
c_str(client_cert_path),
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise XGBoostError(
|
||||
"XGBoost needs to be built with the federated learning plugin "
|
||||
"enabled in order to use this module"
|
||||
|
||||
def __init__( # pylint: disable=R0913, W0231
|
||||
self,
|
||||
n_workers: int,
|
||||
port: int,
|
||||
secure: bool,
|
||||
server_key_path: str = "",
|
||||
server_cert_path: str = "",
|
||||
client_cert_path: str = "",
|
||||
timeout: int = 300,
|
||||
) -> None:
|
||||
handle = ctypes.c_void_p()
|
||||
args = make_jcargs(
|
||||
n_workers=n_workers,
|
||||
port=port,
|
||||
dmlc_communicator="federated",
|
||||
federated_secure=secure,
|
||||
server_key_path=server_key_path,
|
||||
server_cert_path=server_cert_path,
|
||||
client_cert_path=client_cert_path,
|
||||
timeout=int(timeout),
|
||||
)
|
||||
_check_call(_LIB.XGTrackerCreate(args, ctypes.byref(handle)))
|
||||
self.handle = handle
|
||||
|
||||
|
||||
def run_federated_server( # pylint: disable=too-many-arguments
|
||||
n_workers: int,
|
||||
port: int,
|
||||
server_key_path: Optional[str] = None,
|
||||
server_cert_path: Optional[str] = None,
|
||||
client_cert_path: Optional[str] = None,
|
||||
timeout: int = 300,
|
||||
) -> Dict[str, Any]:
|
||||
"""See :py:class:`~xgboost.federated.FederatedTracker` for more info."""
|
||||
args: Dict[str, Any] = {"n_workers": n_workers}
|
||||
secure = all(
|
||||
path is not None
|
||||
for path in [server_key_path, server_cert_path, client_cert_path]
|
||||
)
|
||||
tracker = FederatedTracker(
|
||||
n_workers=n_workers, port=port, secure=secure, timeout=timeout
|
||||
)
|
||||
tracker.start()
|
||||
|
||||
thread = Thread(target=tracker.wait_for)
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
args.update(tracker.worker_args())
|
||||
return args
|
||||
|
||||
@@ -47,21 +47,21 @@ class CommunicatorContext(CCtx): # pylint: disable=too-few-public-methods
|
||||
"""Context with PySpark specific task ID."""
|
||||
|
||||
def __init__(self, context: BarrierTaskContext, **args: Any) -> None:
|
||||
args["DMLC_TASK_ID"] = str(context.partitionId())
|
||||
args["dmlc_task_id"] = str(context.partitionId())
|
||||
super().__init__(**args)
|
||||
|
||||
|
||||
def _start_tracker(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]:
|
||||
"""Start Rabit tracker with n_workers"""
|
||||
env: Dict[str, Any] = {"DMLC_NUM_WORKER": n_workers}
|
||||
args: Dict[str, Any] = {"n_workers": n_workers}
|
||||
host = _get_host_ip(context)
|
||||
rabit_context = RabitTracker(host_ip=host, n_workers=n_workers, sortby="task")
|
||||
env.update(rabit_context.worker_envs())
|
||||
rabit_context.start(n_workers)
|
||||
thread = Thread(target=rabit_context.join)
|
||||
tracker = RabitTracker(n_workers=n_workers, host_ip=host, sortby="task")
|
||||
tracker.start()
|
||||
thread = Thread(target=tracker.wait_for)
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
return env
|
||||
args.update(tracker.worker_args())
|
||||
return args
|
||||
|
||||
|
||||
def _get_rabit_args(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]:
|
||||
|
||||
@@ -111,8 +111,6 @@ def no_sklearn() -> PytestSkip:
|
||||
|
||||
|
||||
def no_dask() -> PytestSkip:
|
||||
if sys.platform.startswith("win"):
|
||||
return {"reason": "Unsupported platform.", "condition": True}
|
||||
return no_mod("dask")
|
||||
|
||||
|
||||
@@ -193,6 +191,10 @@ def no_multiple(*args: Any) -> PytestSkip:
|
||||
return {"condition": condition, "reason": reason}
|
||||
|
||||
|
||||
def skip_win() -> PytestSkip:
|
||||
return {"reason": "Unsupported platform.", "condition": is_windows()}
|
||||
|
||||
|
||||
def skip_s390x() -> PytestSkip:
|
||||
condition = platform.machine() == "s390x"
|
||||
reason = "Known to fail on s390x"
|
||||
@@ -968,18 +970,18 @@ def run_with_rabit(
|
||||
exception_queue.put(e)
|
||||
|
||||
tracker = RabitTracker(host_ip="127.0.0.1", n_workers=world_size)
|
||||
tracker.start(world_size)
|
||||
tracker.start()
|
||||
|
||||
workers = []
|
||||
for _ in range(world_size):
|
||||
worker = threading.Thread(target=run_worker, args=(tracker.worker_envs(),))
|
||||
worker = threading.Thread(target=run_worker, args=(tracker.worker_args(),))
|
||||
workers.append(worker)
|
||||
worker.start()
|
||||
for worker in workers:
|
||||
worker.join()
|
||||
assert exception_queue.empty(), f"Worker failed: {exception_queue.get()}"
|
||||
|
||||
tracker.join()
|
||||
tracker.wait_for()
|
||||
|
||||
|
||||
def column_split_feature_names(
|
||||
|
||||
@@ -1,64 +1,12 @@
|
||||
# pylint: disable=too-many-instance-attributes, too-many-arguments, too-many-branches
|
||||
"""
|
||||
This script is a variant of dmlc-core/dmlc_tracker/tracker.py,
|
||||
which is a specialized version for xgboost tasks.
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
"""Tracker for XGBoost collective."""
|
||||
|
||||
import ctypes
|
||||
import json
|
||||
import socket
|
||||
import struct
|
||||
import sys
|
||||
from threading import Thread
|
||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||
from enum import IntEnum, unique
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
_RingMap = Dict[int, Tuple[int, int]]
|
||||
_TreeMap = Dict[int, List[int]]
|
||||
|
||||
|
||||
class ExSocket:
|
||||
"""
|
||||
Extension of socket to handle recv and send of special data
|
||||
"""
|
||||
|
||||
def __init__(self, sock: socket.socket) -> None:
|
||||
self.sock = sock
|
||||
|
||||
def recvall(self, nbytes: int) -> bytes:
|
||||
"""Receive number of bytes."""
|
||||
res = []
|
||||
nread = 0
|
||||
while nread < nbytes:
|
||||
chunk = self.sock.recv(min(nbytes - nread, 1024))
|
||||
nread += len(chunk)
|
||||
res.append(chunk)
|
||||
return b"".join(res)
|
||||
|
||||
def recvint(self) -> int:
|
||||
"""Receive an integer of 32 bytes"""
|
||||
return struct.unpack("@i", self.recvall(4))[0]
|
||||
|
||||
def sendint(self, value: int) -> None:
|
||||
"""Send an integer of 32 bytes"""
|
||||
self.sock.sendall(struct.pack("@i", value))
|
||||
|
||||
def sendstr(self, value: str) -> None:
|
||||
"""Send a Python string"""
|
||||
self.sendint(len(value))
|
||||
self.sock.sendall(value.encode())
|
||||
|
||||
def recvstr(self) -> str:
|
||||
"""Receive a Python string"""
|
||||
slen = self.recvint()
|
||||
return self.recvall(slen).decode()
|
||||
|
||||
|
||||
# magic number used to verify existence of data
|
||||
MAGIC_NUM = 0xFF99
|
||||
|
||||
|
||||
def get_some_ip(host: str) -> str:
|
||||
"""Get ip from host"""
|
||||
return socket.getaddrinfo(host, None)[0][4][0]
|
||||
from .core import _LIB, _check_call, make_jcargs
|
||||
|
||||
|
||||
def get_family(addr: str) -> int:
|
||||
@@ -66,439 +14,95 @@ def get_family(addr: str) -> int:
|
||||
return socket.getaddrinfo(addr, None)[0][0]
|
||||
|
||||
|
||||
class WorkerEntry:
|
||||
"""Hanlder to each worker."""
|
||||
|
||||
def __init__(self, sock: socket.socket, s_addr: Tuple[str, int]):
|
||||
worker = ExSocket(sock)
|
||||
self.sock = worker
|
||||
self.host = get_some_ip(s_addr[0])
|
||||
magic = worker.recvint()
|
||||
assert magic == MAGIC_NUM, f"invalid magic number={magic} from {self.host}"
|
||||
worker.sendint(MAGIC_NUM)
|
||||
self.rank = worker.recvint()
|
||||
self.world_size = worker.recvint()
|
||||
self.task_id = worker.recvstr()
|
||||
self.cmd = worker.recvstr()
|
||||
self.wait_accept = 0
|
||||
self.port: Optional[int] = None
|
||||
|
||||
def print(self, use_logger: bool) -> None:
|
||||
"""Execute the print command from worker."""
|
||||
msg = self.sock.recvstr()
|
||||
# On dask we use print to avoid setting global verbosity.
|
||||
if use_logger:
|
||||
logging.info(msg.strip())
|
||||
else:
|
||||
print(msg.strip(), flush=True)
|
||||
|
||||
def decide_rank(self, job_map: Dict[str, int]) -> int:
|
||||
"""Get the rank of current entry."""
|
||||
if self.rank >= 0:
|
||||
return self.rank
|
||||
if self.task_id != "NULL" and self.task_id in job_map:
|
||||
return job_map[self.task_id]
|
||||
return -1
|
||||
|
||||
def assign_rank(
|
||||
self,
|
||||
rank: int,
|
||||
wait_conn: Dict[int, "WorkerEntry"],
|
||||
tree_map: _TreeMap,
|
||||
parent_map: Dict[int, int],
|
||||
ring_map: _RingMap,
|
||||
) -> List[int]:
|
||||
"""Assign the rank for current entry."""
|
||||
self.rank = rank
|
||||
nnset = set(tree_map[rank])
|
||||
rprev, next_rank = ring_map[rank]
|
||||
self.sock.sendint(rank)
|
||||
# send parent rank
|
||||
self.sock.sendint(parent_map[rank])
|
||||
# send world size
|
||||
self.sock.sendint(len(tree_map))
|
||||
self.sock.sendint(len(nnset))
|
||||
# send the rprev and next link
|
||||
for r in nnset:
|
||||
self.sock.sendint(r)
|
||||
# send prev link
|
||||
if rprev not in (-1, rank):
|
||||
nnset.add(rprev)
|
||||
self.sock.sendint(rprev)
|
||||
else:
|
||||
self.sock.sendint(-1)
|
||||
# send next link
|
||||
if next_rank not in (-1, rank):
|
||||
nnset.add(next_rank)
|
||||
self.sock.sendint(next_rank)
|
||||
else:
|
||||
self.sock.sendint(-1)
|
||||
|
||||
return self._get_remote(wait_conn, nnset)
|
||||
|
||||
def _get_remote(
|
||||
self, wait_conn: Dict[int, "WorkerEntry"], badset: Set[int]
|
||||
) -> List[int]:
|
||||
while True:
|
||||
conset = []
|
||||
for r in badset:
|
||||
if r in wait_conn:
|
||||
conset.append(r)
|
||||
self.sock.sendint(len(conset))
|
||||
self.sock.sendint(len(badset) - len(conset))
|
||||
for r in conset:
|
||||
self.sock.sendstr(wait_conn[r].host)
|
||||
port = wait_conn[r].port
|
||||
assert port is not None
|
||||
# send port of this node to other workers so that they can call connect
|
||||
self.sock.sendint(port)
|
||||
self.sock.sendint(r)
|
||||
nerr = self.sock.recvint()
|
||||
if nerr != 0:
|
||||
continue
|
||||
self.port = self.sock.recvint()
|
||||
rmset = []
|
||||
# all connection was successuly setup
|
||||
for r in conset:
|
||||
wait_conn[r].wait_accept -= 1
|
||||
if wait_conn[r].wait_accept == 0:
|
||||
rmset.append(r)
|
||||
for r in rmset:
|
||||
wait_conn.pop(r, None)
|
||||
self.wait_accept = len(badset) - len(conset)
|
||||
return rmset
|
||||
|
||||
|
||||
class RabitTracker:
|
||||
"""
|
||||
tracker for rabit
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host_ip: str,
|
||||
n_workers: int,
|
||||
port: int = 0,
|
||||
use_logger: bool = False,
|
||||
sortby: str = "host",
|
||||
) -> None:
|
||||
"""A Python implementation of RABIT tracker.
|
||||
|
||||
Parameters
|
||||
..........
|
||||
use_logger:
|
||||
Use logging.info for tracker print command. When set to False, Python print
|
||||
function is used instead.
|
||||
|
||||
sortby:
|
||||
How to sort the workers for rank assignment. The default is host, but users
|
||||
can set the `DMLC_TASK_ID` via RABIT initialization arguments and obtain
|
||||
deterministic rank assignment. Available options are:
|
||||
- host
|
||||
- task
|
||||
|
||||
"""
|
||||
sock = socket.socket(get_family(host_ip), socket.SOCK_STREAM)
|
||||
sock.bind((host_ip, port))
|
||||
self.port = sock.getsockname()[1]
|
||||
sock.listen(256)
|
||||
self.sock = sock
|
||||
self.host_ip = host_ip
|
||||
self.thread: Optional[Thread] = None
|
||||
self.n_workers = n_workers
|
||||
self._use_logger = use_logger
|
||||
self._sortby = sortby
|
||||
logging.info("start listen on %s:%d", host_ip, self.port)
|
||||
|
||||
def __del__(self) -> None:
|
||||
if hasattr(self, "sock"):
|
||||
self.sock.close()
|
||||
|
||||
@staticmethod
|
||||
def _get_neighbor(rank: int, n_workers: int) -> List[int]:
|
||||
rank = rank + 1
|
||||
ret = []
|
||||
if rank > 1:
|
||||
ret.append(rank // 2 - 1)
|
||||
if rank * 2 - 1 < n_workers:
|
||||
ret.append(rank * 2 - 1)
|
||||
if rank * 2 < n_workers:
|
||||
ret.append(rank * 2)
|
||||
return ret
|
||||
|
||||
def worker_envs(self) -> Dict[str, Union[str, int]]:
|
||||
"""
|
||||
get environment variables for workers
|
||||
can be passed in as args or envs
|
||||
"""
|
||||
return {"DMLC_TRACKER_URI": self.host_ip, "DMLC_TRACKER_PORT": self.port}
|
||||
|
||||
def _get_tree(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int]]:
|
||||
tree_map: _TreeMap = {}
|
||||
parent_map: Dict[int, int] = {}
|
||||
for r in range(n_workers):
|
||||
tree_map[r] = self._get_neighbor(r, n_workers)
|
||||
parent_map[r] = (r + 1) // 2 - 1
|
||||
return tree_map, parent_map
|
||||
|
||||
def find_share_ring(
|
||||
self, tree_map: _TreeMap, parent_map: Dict[int, int], rank: int
|
||||
) -> List[int]:
|
||||
"""
|
||||
get a ring structure that tends to share nodes with the tree
|
||||
return a list starting from rank
|
||||
"""
|
||||
nset = set(tree_map[rank])
|
||||
cset = nset - {parent_map[rank]}
|
||||
if not cset:
|
||||
return [rank]
|
||||
rlst = [rank]
|
||||
cnt = 0
|
||||
for v in cset:
|
||||
vlst = self.find_share_ring(tree_map, parent_map, v)
|
||||
cnt += 1
|
||||
if cnt == len(cset):
|
||||
vlst.reverse()
|
||||
rlst += vlst
|
||||
return rlst
|
||||
|
||||
def get_ring(self, tree_map: _TreeMap, parent_map: Dict[int, int]) -> _RingMap:
|
||||
"""
|
||||
get a ring connection used to recover local data
|
||||
"""
|
||||
assert parent_map[0] == -1
|
||||
rlst = self.find_share_ring(tree_map, parent_map, 0)
|
||||
assert len(rlst) == len(tree_map)
|
||||
ring_map: _RingMap = {}
|
||||
n_workers = len(tree_map)
|
||||
for r in range(n_workers):
|
||||
rprev = (r + n_workers - 1) % n_workers
|
||||
rnext = (r + 1) % n_workers
|
||||
ring_map[rlst[r]] = (rlst[rprev], rlst[rnext])
|
||||
return ring_map
|
||||
|
||||
def get_link_map(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int], _RingMap]:
|
||||
"""
|
||||
get the link map, this is a bit hacky, call for better algorithm
|
||||
to place similar nodes together
|
||||
"""
|
||||
tree_map, parent_map = self._get_tree(n_workers)
|
||||
ring_map = self.get_ring(tree_map, parent_map)
|
||||
rmap = {0: 0}
|
||||
k = 0
|
||||
for i in range(n_workers - 1):
|
||||
k = ring_map[k][1]
|
||||
rmap[k] = i + 1
|
||||
|
||||
ring_map_: _RingMap = {}
|
||||
tree_map_: _TreeMap = {}
|
||||
parent_map_: Dict[int, int] = {}
|
||||
for k, v in ring_map.items():
|
||||
ring_map_[rmap[k]] = (rmap[v[0]], rmap[v[1]])
|
||||
for k, tree_nodes in tree_map.items():
|
||||
tree_map_[rmap[k]] = [rmap[x] for x in tree_nodes]
|
||||
for k, parent in parent_map.items():
|
||||
if k != 0:
|
||||
parent_map_[rmap[k]] = rmap[parent]
|
||||
else:
|
||||
parent_map_[rmap[k]] = -1
|
||||
return tree_map_, parent_map_, ring_map_
|
||||
|
||||
def _sort_pending(self, pending: List[WorkerEntry]) -> List[WorkerEntry]:
|
||||
if self._sortby == "host":
|
||||
pending.sort(key=lambda s: s.host)
|
||||
elif self._sortby == "task":
|
||||
pending.sort(key=lambda s: s.task_id)
|
||||
return pending
|
||||
|
||||
def accept_workers(self, n_workers: int) -> None:
|
||||
"""Wait for all workers to connect to the tracker."""
|
||||
|
||||
# set of nodes that finishes the job
|
||||
shutdown: Dict[int, WorkerEntry] = {}
|
||||
# set of nodes that is waiting for connections
|
||||
wait_conn: Dict[int, WorkerEntry] = {}
|
||||
# maps job id to rank
|
||||
job_map: Dict[str, int] = {}
|
||||
# list of workers that is pending to be assigned rank
|
||||
pending: List[WorkerEntry] = []
|
||||
# lazy initialize tree_map
|
||||
tree_map = None
|
||||
|
||||
while len(shutdown) != n_workers:
|
||||
fd, s_addr = self.sock.accept()
|
||||
s = WorkerEntry(fd, s_addr)
|
||||
if s.cmd == "print":
|
||||
s.print(self._use_logger)
|
||||
continue
|
||||
if s.cmd == "shutdown":
|
||||
assert s.rank >= 0 and s.rank not in shutdown
|
||||
assert s.rank not in wait_conn
|
||||
shutdown[s.rank] = s
|
||||
logging.debug("Received %s signal from %d", s.cmd, s.rank)
|
||||
continue
|
||||
assert s.cmd == "start"
|
||||
# lazily initialize the workers
|
||||
if tree_map is None:
|
||||
assert s.cmd == "start"
|
||||
if s.world_size > 0:
|
||||
n_workers = s.world_size
|
||||
tree_map, parent_map, ring_map = self.get_link_map(n_workers)
|
||||
# set of nodes that is pending for getting up
|
||||
todo_nodes = list(range(n_workers))
|
||||
else:
|
||||
assert s.world_size in (-1, n_workers)
|
||||
if s.cmd == "recover":
|
||||
assert s.rank >= 0
|
||||
|
||||
rank = s.decide_rank(job_map)
|
||||
# batch assignment of ranks
|
||||
if rank == -1:
|
||||
assert todo_nodes
|
||||
pending.append(s)
|
||||
if len(pending) == len(todo_nodes):
|
||||
pending = self._sort_pending(pending)
|
||||
for s in pending:
|
||||
rank = todo_nodes.pop(0)
|
||||
if s.task_id != "NULL":
|
||||
job_map[s.task_id] = rank
|
||||
s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
|
||||
if s.wait_accept > 0:
|
||||
wait_conn[rank] = s
|
||||
logging.debug(
|
||||
"Received %s signal from %s; assign rank %d",
|
||||
s.cmd,
|
||||
s.host,
|
||||
s.rank,
|
||||
)
|
||||
if not todo_nodes:
|
||||
logging.info("@tracker All of %d nodes getting started", n_workers)
|
||||
else:
|
||||
s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
|
||||
logging.debug("Received %s signal from %d", s.cmd, s.rank)
|
||||
if s.wait_accept > 0:
|
||||
wait_conn[rank] = s
|
||||
logging.info("@tracker All nodes finishes job")
|
||||
|
||||
def start(self, n_workers: int) -> None:
|
||||
"""Strat the tracker, it will wait for `n_workers` to connect."""
|
||||
|
||||
def run() -> None:
|
||||
self.accept_workers(n_workers)
|
||||
|
||||
self.thread = Thread(target=run, args=(), daemon=True)
|
||||
self.thread.start()
|
||||
|
||||
def join(self) -> None:
|
||||
"""Wait for the tracker to finish."""
|
||||
while self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(100)
|
||||
|
||||
def alive(self) -> bool:
|
||||
"""Wether the tracker thread is alive"""
|
||||
return self.thread is not None and self.thread.is_alive()
|
||||
|
||||
|
||||
def get_host_ip(host_ip: Optional[str] = None) -> str:
|
||||
"""Get the IP address of current host. If `host_ip` is not none then it will be
|
||||
returned as it's
|
||||
|
||||
"""
|
||||
if host_ip is None or host_ip == "auto":
|
||||
host_ip = "ip"
|
||||
|
||||
if host_ip == "dns":
|
||||
host_ip = socket.getfqdn()
|
||||
elif host_ip == "ip":
|
||||
from socket import gaierror
|
||||
|
||||
try:
|
||||
host_ip = socket.gethostbyname(socket.getfqdn())
|
||||
except gaierror:
|
||||
logging.debug(
|
||||
"gethostbyname(socket.getfqdn()) failed... trying on hostname()"
|
||||
)
|
||||
host_ip = socket.gethostbyname(socket.gethostname())
|
||||
if host_ip.startswith("127."):
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
# doesn't have to be reachable
|
||||
s.connect(("10.255.255.255", 1))
|
||||
host_ip = s.getsockname()[0]
|
||||
|
||||
assert host_ip is not None
|
||||
return host_ip
|
||||
|
||||
|
||||
def start_rabit_tracker(args: argparse.Namespace) -> None:
|
||||
"""Standalone function to start rabit tracker.
|
||||
"""Tracker for the collective used in XGBoost, acting as a coordinator between
|
||||
workers.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
args: arguments to start the rabit tracker.
|
||||
..........
|
||||
sortby:
|
||||
|
||||
How to sort the workers for rank assignment. The default is host, but users can
|
||||
set the `DMLC_TASK_ID` via RABIT initialization arguments and obtain
|
||||
deterministic rank assignment. Available options are:
|
||||
- host
|
||||
- task
|
||||
|
||||
timeout :
|
||||
|
||||
Timeout for constructing the communication group and waiting for the tracker to
|
||||
shutdown when it's instructed to, doesn't apply to communication when tracking
|
||||
is running.
|
||||
|
||||
The timeout value should take the time of data loading and pre-processing into
|
||||
account, due to potential lazy execution.
|
||||
|
||||
The :py:meth:`.wait_for` method has a different timeout parameter that can stop
|
||||
the tracker even if the tracker is still being used. A value error is raised
|
||||
when timeout is reached.
|
||||
|
||||
"""
|
||||
envs = {"DMLC_NUM_WORKER": args.num_workers, "DMLC_NUM_SERVER": args.num_servers}
|
||||
rabit = RabitTracker(
|
||||
host_ip=get_host_ip(args.host_ip), n_workers=args.num_workers, use_logger=True
|
||||
)
|
||||
envs.update(rabit.worker_envs())
|
||||
rabit.start(args.num_workers)
|
||||
sys.stdout.write("DMLC_TRACKER_ENV_START\n")
|
||||
# simply write configuration to stdout
|
||||
for k, v in envs.items():
|
||||
sys.stdout.write(f"{k}={v}\n")
|
||||
sys.stdout.write("DMLC_TRACKER_ENV_END\n")
|
||||
sys.stdout.flush()
|
||||
rabit.join()
|
||||
|
||||
# Sort criterion used when assigning worker ranks: by hostname or by task id.
# Built with the Enum functional API; `unique` guards against aliased values.
_SortBy = unique(IntEnum("_SortBy", {"HOST": 0, "TASK": 1}))
class RabitTracker:
    """Tracker for the collective used in XGBoost, acting as a coordinator
    between workers.

    NOTE(review): the original class header is not visible in this chunk; the
    wrapper is reconstructed from the method bodies (they all take ``self`` and
    share ``self.handle``) — confirm against the full file.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        n_workers: int,
        host_ip: Optional[str],
        port: int = 0,
        sortby: str = "host",
        timeout: int = 0,
    ) -> None:
        """Create and initialize the native tracker handle.

        Parameters
        ----------
        n_workers :
            Number of workers expected to join the communication group.
        host_ip :
            Address of the tracker; validated eagerly when given.
        port :
            Port for the tracker to listen on; 0 lets the system choose.
        sortby :
            How to sort the workers for rank assignment; either "host" or
            "task".
        timeout :
            Timeout (seconds) forwarded to the native tracker.

        Raises
        ------
        ValueError
            If ``sortby`` is neither "host" nor "task".
        """
        handle = ctypes.c_void_p()
        if sortby not in ("host", "task"):
            raise ValueError("Expecting either 'host' or 'task' for sortby.")
        if host_ip is not None:
            get_family(host_ip)  # use python socket to stop early for invalid address
        # Module-level _SortBy enum; the integer value is serialized into the
        # JSON argument blob consumed by the native side.
        args = make_jcargs(
            host=host_ip,
            n_workers=n_workers,
            port=port,
            dmlc_communicator="rabit",
            sortby=_SortBy.HOST if sortby == "host" else _SortBy.TASK,
            timeout=int(timeout),
        )
        _check_call(_LIB.XGTrackerCreate(args, ctypes.byref(handle)))
        self.handle = handle

    def free(self) -> None:
        """Internal function for testing."""
        # Guarded so that repeated calls (e.g. __del__ after an explicit free,
        # or a failed __init__) are no-ops.
        if hasattr(self, "handle"):
            handle = self.handle
            del self.handle
            _check_call(_LIB.XGTrackerFree(handle))

    def __del__(self) -> None:
        # Best-effort release of the native handle at garbage collection.
        self.free()

    def start(self) -> None:
        """Start the tracker. Once started, the client still needs to call the
        :py:meth:`wait_for` method in order to wait for it to finish (think of
        it as a thread).

        """
        _check_call(_LIB.XGTrackerRun(self.handle, make_jcargs()))

    def wait_for(self, timeout: Optional[int] = None) -> None:
        """Wait for the tracker to finish all the work and shutdown. When timeout is
        reached, a value error is raised. By default we don't have timeout since we
        don't know how long it takes for the model to finish training.

        """
        _check_call(_LIB.XGTrackerWaitFor(self.handle, make_jcargs(timeout=timeout)))

    def worker_args(self) -> Dict[str, Union[str, int]]:
        """Get arguments for workers."""
        c_env = ctypes.c_char_p()
        _check_call(_LIB.XGTrackerWorkerArgs(self.handle, ctypes.byref(c_env)))
        assert c_env.value is not None
        # The native side returns a JSON-encoded mapping of worker arguments.
        env = json.loads(c_env.value)
        return env


def main() -> None:
    """Main function if tracker is executed in standalone mode."""
    parser = argparse.ArgumentParser(description="Rabit Tracker start.")
    parser.add_argument(
        "--num-workers",
        required=True,
        type=int,
        help="Number of worker process to be launched.",
    )
    parser.add_argument(
        "--num-servers",
        default=0,
        type=int,
        help="Number of server process to be launched. Only used in PS jobs.",
    )
    parser.add_argument(
        "--host-ip",
        default=None,
        type=str,
        help=(
            "Host IP addressed, this is only needed "
            + "if the host IP cannot be automatically guessed."
        ),
    )
    parser.add_argument(
        "--log-level",
        default="INFO",
        type=str,
        choices=["INFO", "DEBUG"],
        help="Logging level of the logger.",
    )
    args = parser.parse_args()

    # Configure logging before starting the tracker so its messages are shown.
    fmt = "%(asctime)s %(levelname)s %(message)s"
    if args.log_level == "INFO":
        level = logging.INFO
    elif args.log_level == "DEBUG":
        level = logging.DEBUG
    else:
        raise RuntimeError(f"Unknown logging level {args.log_level}")
    logging.basicConfig(format=fmt, level=level)

    if args.num_servers == 0:
        start_rabit_tracker(args)
    else:
        raise RuntimeError("Do not yet support start ps tracker in standalone mode.")


if __name__ == "__main__":
    main()
Reference in New Issue
Block a user