diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 3678f6836..009dea990 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -1,11 +1,9 @@ -# coding: utf-8 # pylint: disable=too-many-arguments, too-many-branches, invalid-name -# pylint: disable=too-many-lines, too-many-locals, no-self-use +# pylint: disable=too-many-lines, too-many-locals """Core XGBoost Library.""" -# pylint: disable=no-name-in-module,import-error +from abc import ABC, abstractmethod from collections.abc import Mapping from typing import List, Optional, Any, Union, Dict, TypeVar -# pylint: enable=no-name-in-module,import-error from typing import Callable, Tuple, cast, Sequence import ctypes import os @@ -123,9 +121,8 @@ def _log_callback(msg: bytes) -> None: def _get_log_callback_func() -> Callable: """Wrap log_callback() method in ctypes callback type""" - # pylint: disable=invalid-name - CALLBACK = ctypes.CFUNCTYPE(None, ctypes.c_char_p) - return CALLBACK(_log_callback) + c_callback = ctypes.CFUNCTYPE(None, ctypes.c_char_p) + return c_callback(_log_callback) def _load_lib() -> ctypes.CDLL: @@ -311,7 +308,7 @@ def _prediction_output(shape, dims, predts, is_cuda): return arr_predict -class DataIter: # pylint: disable=too-many-instance-attributes +class DataIter(ABC): # pylint: disable=too-many-instance-attributes """The interface for user defined data iterator. Parameters @@ -333,9 +330,10 @@ class DataIter: # pylint: disable=too-many-instance-attributes # Stage data in Python until reset or next is called to avoid data being free. self._temporary_data: Optional[Tuple[Any, Any]] = None - def _get_callbacks( + def get_callbacks( self, allow_host: bool, enable_categorical: bool ) -> Tuple[Callable, Callable]: + """Get callback functions for iterating in C.""" assert hasattr(self, "cache_prefix"), "__init__ is not called." 
self._reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)( self._reset_wrapper @@ -369,7 +367,8 @@ class DataIter: # pylint: disable=too-many-instance-attributes self._exception = e.with_traceback(tb) return dft_ret - def _reraise(self) -> None: + def reraise(self) -> None: + """Reraise the exception thrown during iteration.""" self._temporary_data = None if self._exception is not None: # pylint 2.7.0 believes `self._exception` can be None even with `assert @@ -424,10 +423,12 @@ class DataIter: # pylint: disable=too-many-instance-attributes # pylint: disable=not-callable return self._handle_exception(lambda: self.next(data_handle), 0) + @abstractmethod def reset(self) -> None: """Reset the data iterator. Prototype for user defined function.""" raise NotImplementedError() + @abstractmethod def next(self, input_data: Callable) -> int: """Set the next batch of data. @@ -642,8 +643,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes } args = from_pystr_to_cstr(json.dumps(args)) handle = ctypes.c_void_p() - # pylint: disable=protected-access - reset_callback, next_callback = it._get_callbacks( + reset_callback, next_callback = it.get_callbacks( True, enable_categorical ) ret = _LIB.XGDMatrixCreateFromCallback( @@ -654,8 +654,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes args, ctypes.byref(handle), ) - # pylint: disable=protected-access - it._reraise() + it.reraise() # delay check_call to throw intermediate exception first _check_call(ret) self.handle = handle @@ -1225,8 +1224,7 @@ class DeviceQuantileDMatrix(DMatrix): it = SingleBatchInternalIter(data=data, **meta) handle = ctypes.c_void_p() - # pylint: disable=protected-access - reset_callback, next_callback = it._get_callbacks(False, enable_categorical) + reset_callback, next_callback = it.get_callbacks(False, enable_categorical) if it.cache_prefix is not None: raise ValueError( "DeviceQuantileDMatrix doesn't cache data, remove the cache_prefix " @@ -1242,8 +1240,7 @@ class 
DeviceQuantileDMatrix(DMatrix): ctypes.c_int(self.max_bin), ctypes.byref(handle), ) - # pylint: disable=protected-access - it._reraise() + it.reraise() # delay check_call to throw intermediate exception first _check_call(ret) self.handle = handle @@ -1281,6 +1278,21 @@ def _get_booster_layer_trees(model: "Booster") -> Tuple[int, int]: return num_parallel_tree, num_groups +def _configure_metrics(params: Union[Dict, List]) -> Union[Dict, List]: + if ( + isinstance(params, dict) + and "eval_metric" in params + and isinstance(params["eval_metric"], list) + ): + params = dict((k, v) for k, v in params.items()) + eval_metrics = params["eval_metric"] + params.pop("eval_metric", None) + params = list(params.items()) + for eval_metric in eval_metrics: + params += [("eval_metric", eval_metric)] + return params + + class Booster: # pylint: disable=too-many-public-methods """A Booster of XGBoost. @@ -1339,7 +1351,7 @@ class Booster: raise TypeError('Unknown type:', model_file) params = params or {} - params = self._configure_metrics(params.copy()) + params = _configure_metrics(params.copy()) params = self._configure_constraints(params) if isinstance(params, list): params.append(('validate_parameters', True)) @@ -1352,17 +1364,6 @@ class Booster: else: self.booster = 'gbtree' - def _configure_metrics(self, params: Union[Dict, List]) -> Union[Dict, List]: - if isinstance(params, dict) and 'eval_metric' in params \ - and isinstance(params['eval_metric'], list): - params = dict((k, v) for k, v in params.items()) - eval_metrics = params['eval_metric'] - params.pop("eval_metric", None) - params = list(params.items()) - for eval_metric in eval_metrics: - params += [('eval_metric', eval_metric)] - return params - def _transform_monotone_constrains(self, value: Union[Dict[str, int], str]) -> str: if isinstance(value, str): return value @@ -1395,7 +1396,6 @@ class Booster: ) return s + "]" except KeyError as e: - # pylint: disable=raise-missing-from raise ValueError( "Constrained 
features are not a subset of training data feature names" ) from e diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 67265c3e4..b5f03c120 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -171,7 +171,7 @@ def _try_start_tracker( host_ip = addrs[0][0] port = addrs[0][1] rabit_context = RabitTracker( - hostIP=get_host_ip(host_ip), + host_ip=get_host_ip(host_ip), n_workers=n_workers, port=port, use_logger=False, @@ -179,7 +179,7 @@ def _try_start_tracker( else: assert isinstance(addrs[0], str) or addrs[0] is None rabit_context = RabitTracker( - hostIP=get_host_ip(addrs[0]), n_workers=n_workers, use_logger=False + host_ip=get_host_ip(addrs[0]), n_workers=n_workers, use_logger=False ) env.update(rabit_context.worker_envs()) rabit_context.start(n_workers) diff --git a/python-package/xgboost/rabit.py b/python-package/xgboost/rabit.py index e5bd19b9b..29723f4d0 100644 --- a/python-package/xgboost/rabit.py +++ b/python-package/xgboost/rabit.py @@ -1,7 +1,6 @@ -# coding: utf-8 -# pylint: disable= invalid-name """Distributed XGBoost Rabit related API.""" import ctypes +from enum import IntEnum, unique import pickle from typing import Any, TypeVar, Callable, Optional, cast, List, Union @@ -98,7 +97,7 @@ def get_processor_name() -> bytes: return buf.value -T = TypeVar("T") +T = TypeVar("T") # pylint:disable=invalid-name def broadcast(data: T, root: int) -> T: @@ -152,7 +151,8 @@ DTYPE_ENUM__ = { } -class Op: # pylint: disable=too-few-public-methods +@unique +class Op(IntEnum): '''Supported operations for rabit.''' MAX = 0 MIN = 1 @@ -160,18 +160,18 @@ class Op: # pylint: disable=too-few-public-methods OR = 3 -def allreduce( - data: np.ndarray, op: int, prepare_fun: Optional[Callable[[np.ndarray], None]] = None +def allreduce( # pylint:disable=invalid-name + data: np.ndarray, op: Op, prepare_fun: Optional[Callable[[np.ndarray], None]] = None ) -> np.ndarray: """Perform allreduce, return the result. 
Parameters ---------- - data: numpy array + data : Input data. - op: int + op : Reduction operators, can be MIN, MAX, SUM, BITOR - prepare_fun: function + prepare_fun : Lazy preprocessing function, if it is not None, prepare_fun(data) will be called by the function before performing allreduce, to initialize the data If the result of Allreduce can be recovered directly, @@ -179,7 +179,7 @@ def allreduce( Returns ------- - result : array_like + result : The result of allreduce, have same shape as data Notes @@ -196,7 +196,7 @@ def allreduce( if prepare_fun is None: _check_call(_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p), buf.size, DTYPE_ENUM__[buf.dtype], - op, None, None)) + int(op), None, None)) else: func_ptr = ctypes.CFUNCTYPE(None, ctypes.c_void_p) diff --git a/python-package/xgboost/tracker.py b/python-package/xgboost/tracker.py index 4412ef3e9..e19181bf4 100644 --- a/python-package/xgboost/tracker.py +++ b/python-package/xgboost/tracker.py @@ -1,19 +1,16 @@ +# pylint: disable=too-many-instance-attributes, too-many-arguments, too-many-branches """ This script is a variant of dmlc-core/dmlc_tracker/tracker.py, which is a specialized version for xgboost tasks. 
""" - -# pylint: disable=invalid-name, missing-docstring, too-many-arguments, too-many-locals -# pylint: disable=too-many-branches, too-many-statements, too-many-instance-attributes import socket import struct -import time import logging from threading import Thread import argparse import sys -from typing import Dict, List, Tuple, Union, Optional +from typing import Dict, List, Tuple, Union, Optional, Set _RingMap = Dict[int, Tuple[int, int]] _TreeMap = Dict[int, List[int]] @@ -28,6 +25,7 @@ class ExSocket: self.sock = sock def recvall(self, nbytes: int) -> bytes: + """Receive number of bytes.""" res = [] nread = 0 while nread < nbytes: @@ -37,40 +35,47 @@ class ExSocket: return b''.join(res) def recvint(self) -> int: + """Receive an integer of 32 bytes""" return struct.unpack('@i', self.recvall(4))[0] - def sendint(self, n: int) -> None: - self.sock.sendall(struct.pack('@i', n)) + def sendint(self, value: int) -> None: + """Send an integer of 32 bytes""" + self.sock.sendall(struct.pack('@i', value)) - def sendstr(self, s: str) -> None: - self.sendint(len(s)) - self.sock.sendall(s.encode()) + def sendstr(self, value: str) -> None: + """Send a Python string""" + self.sendint(len(value)) + self.sock.sendall(value.encode()) def recvstr(self) -> str: + """Receive a Python string""" slen = self.recvint() return self.recvall(slen).decode() # magic number used to verify existence of data -kMagic = 0xff99 +MAGIC_NUM = 0xff99 def get_some_ip(host: str) -> str: + """Get ip from host""" return socket.getaddrinfo(host, None)[0][4][0] def get_family(addr: str) -> int: + """Get network family from address.""" return socket.getaddrinfo(addr, None)[0][0] class WorkerEntry: + """Hanlder to each worker.""" def __init__(self, sock: socket.socket, s_addr: Tuple[str, int]): worker = ExSocket(sock) self.sock = worker self.host = get_some_ip(s_addr[0]) magic = worker.recvint() - assert magic == kMagic, f"invalid magic number={magic} from {self.host}" - worker.sendint(kMagic) + assert 
magic == MAGIC_NUM, f"invalid magic number={magic} from {self.host}" + worker.sendint(MAGIC_NUM) self.rank = worker.recvint() self.world_size = worker.recvint() self.jobid = worker.recvstr() @@ -78,7 +83,17 @@ class WorkerEntry: self.wait_accept = 0 self.port: Optional[int] = None + def print(self, use_logger: bool) -> None: + """Execute the print command from worker.""" + msg = self.sock.recvstr() + # On dask we use print to avoid setting global verbosity. + if use_logger: + logging.info(msg.strip()) + else: + print(msg.strip(), flush=True) + def decide_rank(self, job_map: Dict[str, int]) -> int: + """Get the rank of current entry.""" if self.rank >= 0: return self.rank if self.jobid != 'NULL' and self.jobid in job_map: @@ -93,6 +108,7 @@ class WorkerEntry: parent_map: Dict[int, int], ring_map: _RingMap, ) -> List[int]: + """Assign the rank for current entry.""" self.rank = rank nnset = set(tree_map[rank]) rprev, rnext = ring_map[rank] @@ -117,6 +133,12 @@ class WorkerEntry: self.sock.sendint(rnext) else: self.sock.sendint(-1) + + return self._get_remote(wait_conn, nnset) + + def _get_remote( + self, wait_conn: Dict[int, "WorkerEntry"], nnset: Set[int] + ) -> List[int]: while True: ngood = self.sock.recvint() goodset = set([]) @@ -158,10 +180,7 @@ class RabitTracker: """ def __init__( - self, hostIP: str, - n_workers: int, - port: int = 0, - use_logger: bool = False, + self, host_ip: str, n_workers: int, port: int = 0, use_logger: bool = False ) -> None: """A Python implementation of RABIT tracker. @@ -172,23 +191,23 @@ class RabitTracker: function is used instead. 
""" - sock = socket.socket(get_family(hostIP), socket.SOCK_STREAM) - sock.bind((hostIP, port)) + sock = socket.socket(get_family(host_ip), socket.SOCK_STREAM) + sock.bind((host_ip, port)) self.port = sock.getsockname()[1] sock.listen(256) self.sock = sock - self.hostIP = hostIP + self.host_ip = host_ip self.thread: Optional[Thread] = None self.n_workers = n_workers self._use_logger = use_logger - logging.info('start listen on %s:%d', hostIP, self.port) + logging.info("start listen on %s:%d", host_ip, self.port) def __del__(self) -> None: if hasattr(self, "sock"): self.sock.close() @staticmethod - def get_neighbor(rank: int, n_workers: int) -> List[int]: + def _get_neighbor(rank: int, n_workers: int) -> List[int]: rank = rank + 1 ret = [] if rank > 1: @@ -204,29 +223,28 @@ class RabitTracker: get environment variables for workers can be passed in as args or envs """ - return {'DMLC_TRACKER_URI': self.hostIP, - 'DMLC_TRACKER_PORT': self.port} + return {'DMLC_TRACKER_URI': self.host_ip, 'DMLC_TRACKER_PORT': self.port} - def get_tree(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int]]: + def _get_tree(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int]]: tree_map: _TreeMap = {} parent_map: Dict[int, int] = {} for r in range(n_workers): - tree_map[r] = self.get_neighbor(r, n_workers) + tree_map[r] = self._get_neighbor(r, n_workers) parent_map[r] = (r + 1) // 2 - 1 return tree_map, parent_map def find_share_ring( - self, tree_map: _TreeMap, parent_map: Dict[int, int], r: int + self, tree_map: _TreeMap, parent_map: Dict[int, int], rank: int ) -> List[int]: """ get a ring structure that tends to share nodes with the tree - return a list starting from r + return a list starting from rank """ - nset = set(tree_map[r]) - cset = nset - set([parent_map[r]]) + nset = set(tree_map[rank]) + cset = nset - set([parent_map[rank]]) if not cset: - return [r] - rlst = [r] + return [rank] + rlst = [rank] cnt = 0 for v in cset: vlst = self.find_share_ring(tree_map, parent_map, v) 
@@ -256,7 +274,7 @@ class RabitTracker: get the link map, this is a bit hacky, call for better algorithm to place similar nodes together """ - tree_map, parent_map = self.get_tree(n_workers) + tree_map, parent_map = self._get_tree(n_workers) ring_map = self.get_ring(tree_map, parent_map) rmap = {0: 0} k = 0 @@ -279,6 +297,7 @@ class RabitTracker: return tree_map_, parent_map_, ring_map_ def accept_workers(self, n_workers: int) -> None: + """Wait for all workers to connect to the tracker.""" # set of nodes that finishes the job shutdown: Dict[int, WorkerEntry] = {} # set of nodes that is waiting for connections @@ -290,18 +309,11 @@ class RabitTracker: # lazy initialize tree_map tree_map = None - start_time = time.time() - while len(shutdown) != n_workers: fd, s_addr = self.sock.accept() s = WorkerEntry(fd, s_addr) if s.cmd == 'print': - msg = s.sock.recvstr() - # On dask we use print to avoid setting global verbosity. - if self._use_logger: - logging.info(msg.strip()) - else: - print(msg.strip(), flush=True) + s.print(self._use_logger) continue if s.cmd == 'shutdown': assert s.rank >= 0 and s.rank not in shutdown @@ -347,13 +359,9 @@ class RabitTracker: if s.wait_accept > 0: wait_conn[rank] = s logging.info('@tracker All nodes finishes job') - end_time = time.time() - logging.info( - '@tracker %s secs between node start and job finish', - str(end_time - start_time) - ) def start(self, n_workers: int) -> None: + """Start the tracker, it will wait for `n_workers` to connect.""" def run() -> None: self.accept_workers(n_workers) @@ -361,36 +369,42 @@ class RabitTracker: self.thread.start() def join(self) -> None: + """Wait for the tracker to finish.""" while self.thread is not None and self.thread.is_alive(): self.thread.join(100) def alive(self) -> bool: + """Whether the tracker thread is alive""" return self.thread is not None and self.thread.is_alive() -def get_host_ip(hostIP: Optional[str] = None) -> str: - if hostIP is None or hostIP == 'auto': - hostIP = 'ip' +def
get_host_ip(host_ip: Optional[str] = None) -> str: + """Get the IP address of current host. If `host_ip` is an explicit address (not + 'auto', 'ip' or 'dns') it will be returned as is. - if hostIP == 'dns': - hostIP = socket.getfqdn() - elif hostIP == 'ip': + """ + if host_ip is None or host_ip == 'auto': + host_ip = 'ip' + + if host_ip == 'dns': + host_ip = socket.getfqdn() + elif host_ip == 'ip': from socket import gaierror try: - hostIP = socket.gethostbyname(socket.getfqdn()) + host_ip = socket.gethostbyname(socket.getfqdn()) except gaierror: logging.debug( 'gethostbyname(socket.getfqdn()) failed... trying on hostname()' ) - hostIP = socket.gethostbyname(socket.gethostname()) - if hostIP.startswith("127."): + host_ip = socket.gethostbyname(socket.gethostname()) + if host_ip.startswith("127."): s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) # doesn't have to be reachable s.connect(('10.255.255.255', 1)) - hostIP = s.getsockname()[0] + host_ip = s.getsockname()[0] - assert hostIP is not None - return hostIP + assert host_ip is not None + return host_ip def start_rabit_tracker(args: argparse.Namespace) -> None: @@ -402,7 +416,7 @@ def start_rabit_tracker(args: argparse.Namespace) -> None: """ envs = {"DMLC_NUM_WORKER": args.num_workers, "DMLC_NUM_SERVER": args.num_servers} rabit = RabitTracker( - hostIP=get_host_ip(args.host_ip), n_workers=args.num_workers, use_logger=True + host_ip=get_host_ip(args.host_ip), n_workers=args.num_workers, use_logger=True ) envs.update(rabit.worker_envs()) rabit.start(args.num_workers) diff --git a/tests/python/test_tracker.py b/tests/python/test_tracker.py index 2f19f6933..2e113898f 100644 --- a/tests/python/test_tracker.py +++ b/tests/python/test_tracker.py @@ -10,7 +10,7 @@ if sys.platform.startswith("win"): def test_rabit_tracker(): - tracker = RabitTracker(hostIP='127.0.0.1', n_workers=1) + tracker = RabitTracker(host_ip='127.0.0.1', n_workers=1) tracker.start(1) worker_env = tracker.worker_envs() rabit_env = []