Cleanup some pylint errors. (#7667)

* Cleanup some pylint errors.

* Cleanup pylint errors in rabit modules.
* Make data iter an abstract class and cleanup private access.
* Cleanup no-self-use for booster.
Jiaming Yuan 2022-02-19 18:53:12 +08:00 committed by GitHub
parent b76c5d54bf
commit f08c5dcb06
5 changed files with 118 additions and 104 deletions
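With DataIter now an abstract base class, user-defined iterators must implement both reset() and next(). Below is a minimal sketch (not part of the diff) of such a subclass, assuming the usual pattern of handing each batch to the input_data callback; the class name and the in-memory batches are illustrative only.

    import numpy as np
    import xgboost

    class BatchedIter(xgboost.DataIter):  # hypothetical example subclass
        """Iterate over an in-memory list of (X, y) batches."""

        def __init__(self, batches):
            self._batches = batches
            self._it = 0
            super().__init__()

        def next(self, input_data):
            # Feed one batch to XGBoost; return 0 once all batches are consumed.
            if self._it == len(self._batches):
                return 0
            X, y = self._batches[self._it]
            input_data(data=X, label=y)
            self._it += 1
            return 1

        def reset(self):
            # Rewind to the first batch.
            self._it = 0

    rng = np.random.default_rng(0)
    it = BatchedIter([(rng.normal(size=(16, 4)), rng.integers(0, 2, size=16)) for _ in range(3)])

Such an iterator is what DMatrix and DeviceQuantileDMatrix drive through the get_callbacks()/reraise() calls shown in the hunks below.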

View File

@@ -1,11 +1,9 @@
-# coding: utf-8
 # pylint: disable=too-many-arguments, too-many-branches, invalid-name
-# pylint: disable=too-many-lines, too-many-locals, no-self-use
+# pylint: disable=too-many-lines, too-many-locals
 """Core XGBoost Library."""
-# pylint: disable=no-name-in-module,import-error
+from abc import ABC, abstractmethod
 from collections.abc import Mapping
 from typing import List, Optional, Any, Union, Dict, TypeVar
-# pylint: enable=no-name-in-module,import-error
 from typing import Callable, Tuple, cast, Sequence
 import ctypes
 import os
@@ -123,9 +121,8 @@ def _log_callback(msg: bytes) -> None:
 def _get_log_callback_func() -> Callable:
     """Wrap log_callback() method in ctypes callback type"""
-    # pylint: disable=invalid-name
-    CALLBACK = ctypes.CFUNCTYPE(None, ctypes.c_char_p)
-    return CALLBACK(_log_callback)
+    c_callback = ctypes.CFUNCTYPE(None, ctypes.c_char_p)
+    return c_callback(_log_callback)

 def _load_lib() -> ctypes.CDLL:
@@ -311,7 +308,7 @@ def _prediction_output(shape, dims, predts, is_cuda):
     return arr_predict

-class DataIter: # pylint: disable=too-many-instance-attributes
+class DataIter(ABC): # pylint: disable=too-many-instance-attributes
     """The interface for user defined data iterator.

     Parameters
@@ -333,9 +330,10 @@ class DataIter: # pylint: disable=too-many-instance-attributes
         # Stage data in Python until reset or next is called to avoid data being free.
         self._temporary_data: Optional[Tuple[Any, Any]] = None

-    def _get_callbacks(
+    def get_callbacks(
         self, allow_host: bool, enable_categorical: bool
     ) -> Tuple[Callable, Callable]:
+        """Get callback functions for iterating in C."""
         assert hasattr(self, "cache_prefix"), "__init__ is not called."
         self._reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(
             self._reset_wrapper
@@ -369,7 +367,8 @@ class DataIter: # pylint: disable=too-many-instance-attributes
             self._exception = e.with_traceback(tb)
         return dft_ret

-    def _reraise(self) -> None:
+    def reraise(self) -> None:
+        """Reraise the exception thrown during iteration."""
         self._temporary_data = None
         if self._exception is not None:
             # pylint 2.7.0 believes `self._exception` can be None even with `assert
@@ -424,10 +423,12 @@ class DataIter: # pylint: disable=too-many-instance-attributes
         # pylint: disable=not-callable
         return self._handle_exception(lambda: self.next(data_handle), 0)

+    @abstractmethod
     def reset(self) -> None:
         """Reset the data iterator. Prototype for user defined function."""
         raise NotImplementedError()

+    @abstractmethod
     def next(self, input_data: Callable) -> int:
         """Set the next batch of data.
@@ -642,8 +643,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
         }
         args = from_pystr_to_cstr(json.dumps(args))
         handle = ctypes.c_void_p()
-        # pylint: disable=protected-access
-        reset_callback, next_callback = it._get_callbacks(
+        reset_callback, next_callback = it.get_callbacks(
             True, enable_categorical
         )
         ret = _LIB.XGDMatrixCreateFromCallback(
@@ -654,8 +654,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
             args,
             ctypes.byref(handle),
         )
-        # pylint: disable=protected-access
-        it._reraise()
+        it.reraise()
         # delay check_call to throw intermediate exception first
         _check_call(ret)
         self.handle = handle
@@ -1225,8 +1224,7 @@ class DeviceQuantileDMatrix(DMatrix):
         it = SingleBatchInternalIter(data=data, **meta)
         handle = ctypes.c_void_p()
-        # pylint: disable=protected-access
-        reset_callback, next_callback = it._get_callbacks(False, enable_categorical)
+        reset_callback, next_callback = it.get_callbacks(False, enable_categorical)
         if it.cache_prefix is not None:
             raise ValueError(
                 "DeviceQuantileDMatrix doesn't cache data, remove the cache_prefix "
@@ -1242,8 +1240,7 @@ class DeviceQuantileDMatrix(DMatrix):
             ctypes.c_int(self.max_bin),
             ctypes.byref(handle),
         )
-        # pylint: disable=protected-access
-        it._reraise()
+        it.reraise()
         # delay check_call to throw intermediate exception first
         _check_call(ret)
         self.handle = handle
@@ -1281,6 +1278,21 @@ def _get_booster_layer_trees(model: "Booster") -> Tuple[int, int]:
     return num_parallel_tree, num_groups

+def _configure_metrics(params: Union[Dict, List]) -> Union[Dict, List]:
+    if (
+        isinstance(params, dict)
+        and "eval_metric" in params
+        and isinstance(params["eval_metric"], list)
+    ):
+        params = dict((k, v) for k, v in params.items())
+        eval_metrics = params["eval_metric"]
+        params.pop("eval_metric", None)
+        params = list(params.items())
+        for eval_metric in eval_metrics:
+            params += [("eval_metric", eval_metric)]
+    return params

 class Booster:
     # pylint: disable=too-many-public-methods
     """A Booster of XGBoost.
@@ -1339,7 +1351,7 @@ class Booster:
             raise TypeError('Unknown type:', model_file)

         params = params or {}
-        params = self._configure_metrics(params.copy())
+        params = _configure_metrics(params.copy())
         params = self._configure_constraints(params)
         if isinstance(params, list):
             params.append(('validate_parameters', True))
@@ -1352,17 +1364,6 @@ class Booster:
         else:
             self.booster = 'gbtree'

-    def _configure_metrics(self, params: Union[Dict, List]) -> Union[Dict, List]:
-        if isinstance(params, dict) and 'eval_metric' in params \
-                and isinstance(params['eval_metric'], list):
-            params = dict((k, v) for k, v in params.items())
-            eval_metrics = params['eval_metric']
-            params.pop("eval_metric", None)
-            params = list(params.items())
-            for eval_metric in eval_metrics:
-                params += [('eval_metric', eval_metric)]
-        return params
-
     def _transform_monotone_constrains(self, value: Union[Dict[str, int], str]) -> str:
         if isinstance(value, str):
             return value
@@ -1395,7 +1396,6 @@ class Booster:
             )
             return s + "]"
         except KeyError as e:
-            # pylint: disable=raise-missing-from
             raise ValueError(
                 "Constrained features are not a subset of training data feature names"
             ) from e
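The eval_metric expansion is now a module-level helper instead of a Booster method (part of the no-self-use cleanup). Copied below as a standalone illustration of its behaviour; only the final call and print are new, and the real function is private to xgboost.core.

    from typing import Dict, List, Union

    def _configure_metrics(params: Union[Dict, List]) -> Union[Dict, List]:
        # A dict with a list-valued "eval_metric" is flattened into a list of
        # pairs so that each metric becomes its own ("eval_metric", ...) entry.
        if (
            isinstance(params, dict)
            and "eval_metric" in params
            and isinstance(params["eval_metric"], list)
        ):
            params = dict((k, v) for k, v in params.items())
            eval_metrics = params["eval_metric"]
            params.pop("eval_metric", None)
            params = list(params.items())
            for eval_metric in eval_metrics:
                params += [("eval_metric", eval_metric)]
        return params

    print(_configure_metrics({"eta": 0.1, "eval_metric": ["auc", "logloss"]}))
    # [('eta', 0.1), ('eval_metric', 'auc'), ('eval_metric', 'logloss')]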

View File

@@ -171,7 +171,7 @@ def _try_start_tracker(
             host_ip = addrs[0][0]
             port = addrs[0][1]
             rabit_context = RabitTracker(
-                hostIP=get_host_ip(host_ip),
+                host_ip=get_host_ip(host_ip),
                 n_workers=n_workers,
                 port=port,
                 use_logger=False,
@@ -179,7 +179,7 @@ def _try_start_tracker(
         else:
             assert isinstance(addrs[0], str) or addrs[0] is None
             rabit_context = RabitTracker(
-                hostIP=get_host_ip(addrs[0]), n_workers=n_workers, use_logger=False
+                host_ip=get_host_ip(addrs[0]), n_workers=n_workers, use_logger=False
             )
         env.update(rabit_context.worker_envs())
         rabit_context.start(n_workers)

View File

@@ -1,7 +1,6 @@
-# coding: utf-8
-# pylint: disable= invalid-name
 """Distributed XGBoost Rabit related API."""
 import ctypes
+from enum import IntEnum, unique
 import pickle
 from typing import Any, TypeVar, Callable, Optional, cast, List, Union
@@ -98,7 +97,7 @@ def get_processor_name() -> bytes:
     return buf.value

-T = TypeVar("T")
+T = TypeVar("T")  # pylint:disable=invalid-name

 def broadcast(data: T, root: int) -> T:
@@ -152,7 +151,8 @@ DTYPE_ENUM__ = {
 }

-class Op: # pylint: disable=too-few-public-methods
+@unique
+class Op(IntEnum):
     '''Supported operations for rabit.'''
     MAX = 0
     MIN = 1
@@ -160,18 +160,18 @@ class Op: # pylint: disable=too-few-public-methods
     OR = 3

-def allreduce(
-    data: np.ndarray, op: int, prepare_fun: Optional[Callable[[np.ndarray], None]] = None
+def allreduce(  # pylint:disable=invalid-name
+    data: np.ndarray, op: Op, prepare_fun: Optional[Callable[[np.ndarray], None]] = None
 ) -> np.ndarray:
     """Perform allreduce, return the result.

     Parameters
     ----------
-    data: numpy array
+    data :
         Input data.
-    op: int
+    op :
         Reduction operators, can be MIN, MAX, SUM, BITOR
-    prepare_fun: function
+    prepare_fun :
         Lazy preprocessing function, if it is not None, prepare_fun(data)
         will be called by the function before performing allreduce, to initialize the data
         If the result of Allreduce can be recovered directly,
@@ -179,7 +179,7 @@ def allreduce(
     Returns
     -------
-    result : array_like
+    result :
         The result of allreduce, have same shape as data

     Notes
@@ -196,7 +196,7 @@ def allreduce(
     if prepare_fun is None:
         _check_call(_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
                                         buf.size, DTYPE_ENUM__[buf.dtype],
-                                        op, None, None))
+                                        int(op), None, None))
     else:
         func_ptr = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
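Since Op is now an IntEnum, callers pass an enum member and the wrapper converts it with int(op) before calling the C API. A short sketch of the call site, assuming rabit can be initialised in a single process (where the reduction simply returns its input unchanged):

    import numpy as np
    from xgboost import rabit

    rabit.init()  # no tracker configured: falls back to a single-worker engine
    values = np.asarray([1.0, 2.0, 3.0])
    total = rabit.allreduce(values, rabit.Op.SUM)  # Op.SUM is an IntEnum member (== 2)
    print(total)  # with one worker the result equals the input
    rabit.finalize()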

View File

@@ -1,19 +1,16 @@
+# pylint: disable=too-many-instance-attributes, too-many-arguments, too-many-branches
 """
 This script is a variant of dmlc-core/dmlc_tracker/tracker.py,
 which is a specialized version for xgboost tasks.
 """
-# pylint: disable=invalid-name, missing-docstring, too-many-arguments, too-many-locals
-# pylint: disable=too-many-branches, too-many-statements, too-many-instance-attributes
 import socket
 import struct
-import time
 import logging
 from threading import Thread
 import argparse
 import sys
-from typing import Dict, List, Tuple, Union, Optional
+from typing import Dict, List, Tuple, Union, Optional, Set

 _RingMap = Dict[int, Tuple[int, int]]
 _TreeMap = Dict[int, List[int]]
@@ -28,6 +25,7 @@ class ExSocket:
         self.sock = sock

     def recvall(self, nbytes: int) -> bytes:
+        """Receive number of bytes."""
         res = []
         nread = 0
         while nread < nbytes:
@@ -37,40 +35,47 @@ class ExSocket:
         return b''.join(res)

     def recvint(self) -> int:
+        """Receive an integer of 32 bytes"""
         return struct.unpack('@i', self.recvall(4))[0]

-    def sendint(self, n: int) -> None:
-        self.sock.sendall(struct.pack('@i', n))
+    def sendint(self, value: int) -> None:
+        """Send an integer of 32 bytes"""
+        self.sock.sendall(struct.pack('@i', value))

-    def sendstr(self, s: str) -> None:
-        self.sendint(len(s))
-        self.sock.sendall(s.encode())
+    def sendstr(self, value: str) -> None:
+        """Send a Python string"""
+        self.sendint(len(value))
+        self.sock.sendall(value.encode())

     def recvstr(self) -> str:
+        """Receive a Python string"""
         slen = self.recvint()
         return self.recvall(slen).decode()

 # magic number used to verify existence of data
-kMagic = 0xff99
+MAGIC_NUM = 0xff99

 def get_some_ip(host: str) -> str:
+    """Get ip from host"""
     return socket.getaddrinfo(host, None)[0][4][0]

 def get_family(addr: str) -> int:
+    """Get network family from address."""
     return socket.getaddrinfo(addr, None)[0][0]

 class WorkerEntry:
+    """Hanlder to each worker."""
     def __init__(self, sock: socket.socket, s_addr: Tuple[str, int]):
         worker = ExSocket(sock)
         self.sock = worker
         self.host = get_some_ip(s_addr[0])
         magic = worker.recvint()
-        assert magic == kMagic, f"invalid magic number={magic} from {self.host}"
-        worker.sendint(kMagic)
+        assert magic == MAGIC_NUM, f"invalid magic number={magic} from {self.host}"
+        worker.sendint(MAGIC_NUM)
         self.rank = worker.recvint()
         self.world_size = worker.recvint()
         self.jobid = worker.recvstr()
@@ -78,7 +83,17 @@ class WorkerEntry:
         self.wait_accept = 0
         self.port: Optional[int] = None

+    def print(self, use_logger: bool) -> None:
+        """Execute the print command from worker."""
+        msg = self.sock.recvstr()
+        # On dask we use print to avoid setting global verbosity.
+        if use_logger:
+            logging.info(msg.strip())
+        else:
+            print(msg.strip(), flush=True)

     def decide_rank(self, job_map: Dict[str, int]) -> int:
+        """Get the rank of current entry."""
         if self.rank >= 0:
             return self.rank
         if self.jobid != 'NULL' and self.jobid in job_map:
@@ -93,6 +108,7 @@ class WorkerEntry:
         parent_map: Dict[int, int],
         ring_map: _RingMap,
     ) -> List[int]:
+        """Assign the rank for current entry."""
         self.rank = rank
         nnset = set(tree_map[rank])
         rprev, rnext = ring_map[rank]
@@ -117,6 +133,12 @@ class WorkerEntry:
             self.sock.sendint(rnext)
         else:
             self.sock.sendint(-1)
+        return self._get_remote(wait_conn, nnset)

+    def _get_remote(
+        self, wait_conn: Dict[int, "WorkerEntry"], nnset: Set[int]
+    ) -> List[int]:
         while True:
             ngood = self.sock.recvint()
             goodset = set([])
@@ -158,10 +180,7 @@ class RabitTracker:
     """
     def __init__(
-        self, hostIP: str,
-        n_workers: int,
-        port: int = 0,
-        use_logger: bool = False,
+        self, host_ip: str, n_workers: int, port: int = 0, use_logger: bool = False
     ) -> None:
         """A Python implementation of RABIT tracker.
@@ -172,23 +191,23 @@ class RabitTracker:
         function is used instead.
         """
-        sock = socket.socket(get_family(hostIP), socket.SOCK_STREAM)
-        sock.bind((hostIP, port))
+        sock = socket.socket(get_family(host_ip), socket.SOCK_STREAM)
+        sock.bind((host_ip, port))
         self.port = sock.getsockname()[1]
         sock.listen(256)
         self.sock = sock
-        self.hostIP = hostIP
+        self.host_ip = host_ip
         self.thread: Optional[Thread] = None
         self.n_workers = n_workers
         self._use_logger = use_logger
-        logging.info('start listen on %s:%d', hostIP, self.port)
+        logging.info("start listen on %s:%d", host_ip, self.port)

     def __del__(self) -> None:
         if hasattr(self, "sock"):
             self.sock.close()

     @staticmethod
-    def get_neighbor(rank: int, n_workers: int) -> List[int]:
+    def _get_neighbor(rank: int, n_workers: int) -> List[int]:
         rank = rank + 1
         ret = []
         if rank > 1:
@@ -204,29 +223,28 @@ class RabitTracker:
         get environment variables for workers
         can be passed in as args or envs
         """
-        return {'DMLC_TRACKER_URI': self.hostIP,
-                'DMLC_TRACKER_PORT': self.port}
+        return {'DMLC_TRACKER_URI': self.host_ip, 'DMLC_TRACKER_PORT': self.port}

-    def get_tree(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int]]:
+    def _get_tree(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int]]:
         tree_map: _TreeMap = {}
         parent_map: Dict[int, int] = {}
         for r in range(n_workers):
-            tree_map[r] = self.get_neighbor(r, n_workers)
+            tree_map[r] = self._get_neighbor(r, n_workers)
             parent_map[r] = (r + 1) // 2 - 1
         return tree_map, parent_map

     def find_share_ring(
-        self, tree_map: _TreeMap, parent_map: Dict[int, int], r: int
+        self, tree_map: _TreeMap, parent_map: Dict[int, int], rank: int
     ) -> List[int]:
         """
         get a ring structure that tends to share nodes with the tree
-        return a list starting from r
+        return a list starting from rank
         """
-        nset = set(tree_map[r])
-        cset = nset - set([parent_map[r]])
+        nset = set(tree_map[rank])
+        cset = nset - set([parent_map[rank]])
         if not cset:
-            return [r]
-        rlst = [r]
+            return [rank]
+        rlst = [rank]
         cnt = 0
         for v in cset:
             vlst = self.find_share_ring(tree_map, parent_map, v)
@@ -256,7 +274,7 @@ class RabitTracker:
         get the link map, this is a bit hacky, call for better algorithm
         to place similar nodes together
         """
-        tree_map, parent_map = self.get_tree(n_workers)
+        tree_map, parent_map = self._get_tree(n_workers)
         ring_map = self.get_ring(tree_map, parent_map)
         rmap = {0: 0}
         k = 0
@@ -279,6 +297,7 @@ class RabitTracker:
         return tree_map_, parent_map_, ring_map_

     def accept_workers(self, n_workers: int) -> None:
+        """Wait for all workers to connect to the tracker."""
         # set of nodes that finishes the job
         shutdown: Dict[int, WorkerEntry] = {}
         # set of nodes that is waiting for connections
@@ -290,18 +309,11 @@ class RabitTracker:
         # lazy initialize tree_map
         tree_map = None

-        start_time = time.time()
         while len(shutdown) != n_workers:
             fd, s_addr = self.sock.accept()
             s = WorkerEntry(fd, s_addr)
             if s.cmd == 'print':
-                msg = s.sock.recvstr()
-                # On dask we use print to avoid setting global verbosity.
-                if self._use_logger:
-                    logging.info(msg.strip())
-                else:
-                    print(msg.strip(), flush=True)
+                s.print(self._use_logger)
                 continue
             if s.cmd == 'shutdown':
                 assert s.rank >= 0 and s.rank not in shutdown
@@ -347,13 +359,9 @@ class RabitTracker:
             if s.wait_accept > 0:
                 wait_conn[rank] = s
         logging.info('@tracker All nodes finishes job')
-        end_time = time.time()
-        logging.info(
-            '@tracker %s secs between node start and job finish',
-            str(end_time - start_time)
-        )

     def start(self, n_workers: int) -> None:
+        """Strat the tracker, it will wait for `n_workers` to connect."""
         def run() -> None:
             self.accept_workers(n_workers)
@@ -361,36 +369,42 @@ class RabitTracker:
         self.thread.start()

     def join(self) -> None:
+        """Wait for the tracker to finish."""
         while self.thread is not None and self.thread.is_alive():
             self.thread.join(100)

     def alive(self) -> bool:
+        """Wether the tracker thread is alive"""
         return self.thread is not None and self.thread.is_alive()

-def get_host_ip(hostIP: Optional[str] = None) -> str:
-    if hostIP is None or hostIP == 'auto':
-        hostIP = 'ip'
-    if hostIP == 'dns':
-        hostIP = socket.getfqdn()
-    elif hostIP == 'ip':
+def get_host_ip(host_ip: Optional[str] = None) -> str:
+    """Get the IP address of current host. If `host_ip` is not none then it will be
+    returned as it's
+    """
+    if host_ip is None or host_ip == 'auto':
+        host_ip = 'ip'
+    if host_ip == 'dns':
+        host_ip = socket.getfqdn()
+    elif host_ip == 'ip':
         from socket import gaierror
         try:
-            hostIP = socket.gethostbyname(socket.getfqdn())
+            host_ip = socket.gethostbyname(socket.getfqdn())
         except gaierror:
             logging.debug(
                 'gethostbyname(socket.getfqdn()) failed... trying on hostname()'
             )
-            hostIP = socket.gethostbyname(socket.gethostname())
-    if hostIP.startswith("127."):
+            host_ip = socket.gethostbyname(socket.gethostname())
+    if host_ip.startswith("127."):
         s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
         # doesn't have to be reachable
         s.connect(('10.255.255.255', 1))
-        hostIP = s.getsockname()[0]
+        host_ip = s.getsockname()[0]

-    assert hostIP is not None
-    return hostIP
+    assert host_ip is not None
+    return host_ip

 def start_rabit_tracker(args: argparse.Namespace) -> None:
@@ -402,7 +416,7 @@ def start_rabit_tracker(args: argparse.Namespace) -> None:
     """
     envs = {"DMLC_NUM_WORKER": args.num_workers, "DMLC_NUM_SERVER": args.num_servers}
     rabit = RabitTracker(
-        hostIP=get_host_ip(args.host_ip), n_workers=args.num_workers, use_logger=True
+        host_ip=get_host_ip(args.host_ip), n_workers=args.num_workers, use_logger=True
     )
     envs.update(rabit.worker_envs())
     rabit.start(args.num_workers)

View File

@@ -10,7 +10,7 @@ if sys.platform.startswith("win"):
 def test_rabit_tracker():
-    tracker = RabitTracker(hostIP='127.0.0.1', n_workers=1)
+    tracker = RabitTracker(host_ip='127.0.0.1', n_workers=1)
     tracker.start(1)
     worker_env = tracker.worker_envs()
     rabit_env = []