Cleanup some pylint errors. (#7667)

* Cleanup some pylint errors.

* Cleanup pylint errors in rabit modules.
* Make data iter an abstract class and cleanup private access.
* Cleanup no-self-use for booster.
This commit is contained in:
Jiaming Yuan 2022-02-19 18:53:12 +08:00 committed by GitHub
parent b76c5d54bf
commit f08c5dcb06
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 118 additions and 104 deletions

View File

@ -1,11 +1,9 @@
# coding: utf-8
# pylint: disable=too-many-arguments, too-many-branches, invalid-name
# pylint: disable=too-many-lines, too-many-locals, no-self-use
# pylint: disable=too-many-lines, too-many-locals
"""Core XGBoost Library."""
# pylint: disable=no-name-in-module,import-error
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import List, Optional, Any, Union, Dict, TypeVar
# pylint: enable=no-name-in-module,import-error
from typing import Callable, Tuple, cast, Sequence
import ctypes
import os
@ -123,9 +121,8 @@ def _log_callback(msg: bytes) -> None:
def _get_log_callback_func() -> Callable:
"""Wrap log_callback() method in ctypes callback type"""
# pylint: disable=invalid-name
CALLBACK = ctypes.CFUNCTYPE(None, ctypes.c_char_p)
return CALLBACK(_log_callback)
c_callback = ctypes.CFUNCTYPE(None, ctypes.c_char_p)
return c_callback(_log_callback)
def _load_lib() -> ctypes.CDLL:
@ -311,7 +308,7 @@ def _prediction_output(shape, dims, predts, is_cuda):
return arr_predict
class DataIter: # pylint: disable=too-many-instance-attributes
class DataIter(ABC): # pylint: disable=too-many-instance-attributes
"""The interface for user defined data iterator.
Parameters
@ -333,9 +330,10 @@ class DataIter: # pylint: disable=too-many-instance-attributes
# Stage data in Python until reset or next is called to avoid data being free.
self._temporary_data: Optional[Tuple[Any, Any]] = None
def _get_callbacks(
def get_callbacks(
self, allow_host: bool, enable_categorical: bool
) -> Tuple[Callable, Callable]:
"""Get callback functions for iterating in C."""
assert hasattr(self, "cache_prefix"), "__init__ is not called."
self._reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(
self._reset_wrapper
@ -369,7 +367,8 @@ class DataIter: # pylint: disable=too-many-instance-attributes
self._exception = e.with_traceback(tb)
return dft_ret
def _reraise(self) -> None:
def reraise(self) -> None:
"""Reraise the exception thrown during iteration."""
self._temporary_data = None
if self._exception is not None:
# pylint 2.7.0 believes `self._exception` can be None even with `assert
@ -424,10 +423,12 @@ class DataIter: # pylint: disable=too-many-instance-attributes
# pylint: disable=not-callable
return self._handle_exception(lambda: self.next(data_handle), 0)
@abstractmethod
def reset(self) -> None:
"""Reset the data iterator. Prototype for user defined function."""
raise NotImplementedError()
@abstractmethod
def next(self, input_data: Callable) -> int:
"""Set the next batch of data.
@ -642,8 +643,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
}
args = from_pystr_to_cstr(json.dumps(args))
handle = ctypes.c_void_p()
# pylint: disable=protected-access
reset_callback, next_callback = it._get_callbacks(
reset_callback, next_callback = it.get_callbacks(
True, enable_categorical
)
ret = _LIB.XGDMatrixCreateFromCallback(
@ -654,8 +654,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
args,
ctypes.byref(handle),
)
# pylint: disable=protected-access
it._reraise()
it.reraise()
# delay check_call to throw intermediate exception first
_check_call(ret)
self.handle = handle
@ -1225,8 +1224,7 @@ class DeviceQuantileDMatrix(DMatrix):
it = SingleBatchInternalIter(data=data, **meta)
handle = ctypes.c_void_p()
# pylint: disable=protected-access
reset_callback, next_callback = it._get_callbacks(False, enable_categorical)
reset_callback, next_callback = it.get_callbacks(False, enable_categorical)
if it.cache_prefix is not None:
raise ValueError(
"DeviceQuantileDMatrix doesn't cache data, remove the cache_prefix "
@ -1242,8 +1240,7 @@ class DeviceQuantileDMatrix(DMatrix):
ctypes.c_int(self.max_bin),
ctypes.byref(handle),
)
# pylint: disable=protected-access
it._reraise()
it.reraise()
# delay check_call to throw intermediate exception first
_check_call(ret)
self.handle = handle
@ -1281,6 +1278,21 @@ def _get_booster_layer_trees(model: "Booster") -> Tuple[int, int]:
return num_parallel_tree, num_groups
def _configure_metrics(params: Union[Dict, List]) -> Union[Dict, List]:
if (
isinstance(params, dict)
and "eval_metric" in params
and isinstance(params["eval_metric"], list)
):
params = dict((k, v) for k, v in params.items())
eval_metrics = params["eval_metric"]
params.pop("eval_metric", None)
params = list(params.items())
for eval_metric in eval_metrics:
params += [("eval_metric", eval_metric)]
return params
class Booster:
# pylint: disable=too-many-public-methods
"""A Booster of XGBoost.
@ -1339,7 +1351,7 @@ class Booster:
raise TypeError('Unknown type:', model_file)
params = params or {}
params = self._configure_metrics(params.copy())
params = _configure_metrics(params.copy())
params = self._configure_constraints(params)
if isinstance(params, list):
params.append(('validate_parameters', True))
@ -1352,17 +1364,6 @@ class Booster:
else:
self.booster = 'gbtree'
def _configure_metrics(self, params: Union[Dict, List]) -> Union[Dict, List]:
if isinstance(params, dict) and 'eval_metric' in params \
and isinstance(params['eval_metric'], list):
params = dict((k, v) for k, v in params.items())
eval_metrics = params['eval_metric']
params.pop("eval_metric", None)
params = list(params.items())
for eval_metric in eval_metrics:
params += [('eval_metric', eval_metric)]
return params
def _transform_monotone_constrains(self, value: Union[Dict[str, int], str]) -> str:
if isinstance(value, str):
return value
@ -1395,7 +1396,6 @@ class Booster:
)
return s + "]"
except KeyError as e:
# pylint: disable=raise-missing-from
raise ValueError(
"Constrained features are not a subset of training data feature names"
) from e

View File

@ -171,7 +171,7 @@ def _try_start_tracker(
host_ip = addrs[0][0]
port = addrs[0][1]
rabit_context = RabitTracker(
hostIP=get_host_ip(host_ip),
host_ip=get_host_ip(host_ip),
n_workers=n_workers,
port=port,
use_logger=False,
@ -179,7 +179,7 @@ def _try_start_tracker(
else:
assert isinstance(addrs[0], str) or addrs[0] is None
rabit_context = RabitTracker(
hostIP=get_host_ip(addrs[0]), n_workers=n_workers, use_logger=False
host_ip=get_host_ip(addrs[0]), n_workers=n_workers, use_logger=False
)
env.update(rabit_context.worker_envs())
rabit_context.start(n_workers)

View File

@ -1,7 +1,6 @@
# coding: utf-8
# pylint: disable= invalid-name
"""Distributed XGBoost Rabit related API."""
import ctypes
from enum import IntEnum, unique
import pickle
from typing import Any, TypeVar, Callable, Optional, cast, List, Union
@ -98,7 +97,7 @@ def get_processor_name() -> bytes:
return buf.value
T = TypeVar("T")
T = TypeVar("T") # pylint:disable=invalid-name
def broadcast(data: T, root: int) -> T:
@ -152,7 +151,8 @@ DTYPE_ENUM__ = {
}
class Op: # pylint: disable=too-few-public-methods
@unique
class Op(IntEnum):
'''Supported operations for rabit.'''
MAX = 0
MIN = 1
@ -160,18 +160,18 @@ class Op: # pylint: disable=too-few-public-methods
OR = 3
def allreduce(
data: np.ndarray, op: int, prepare_fun: Optional[Callable[[np.ndarray], None]] = None
def allreduce( # pylint:disable=invalid-name
data: np.ndarray, op: Op, prepare_fun: Optional[Callable[[np.ndarray], None]] = None
) -> np.ndarray:
"""Perform allreduce, return the result.
Parameters
----------
data: numpy array
data :
Input data.
op: int
op :
Reduction operators, can be MIN, MAX, SUM, BITOR
prepare_fun: function
prepare_fun :
Lazy preprocessing function, if it is not None, prepare_fun(data)
will be called by the function before performing allreduce, to initialize the data
If the result of Allreduce can be recovered directly,
@ -179,7 +179,7 @@ def allreduce(
Returns
-------
result : array_like
result :
The result of allreduce, have same shape as data
Notes
@ -196,7 +196,7 @@ def allreduce(
if prepare_fun is None:
_check_call(_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
buf.size, DTYPE_ENUM__[buf.dtype],
op, None, None))
int(op), None, None))
else:
func_ptr = ctypes.CFUNCTYPE(None, ctypes.c_void_p)

View File

@ -1,19 +1,16 @@
# pylint: disable=too-many-instance-attributes, too-many-arguments, too-many-branches
"""
This script is a variant of dmlc-core/dmlc_tracker/tracker.py,
which is a specialized version for xgboost tasks.
"""
# pylint: disable=invalid-name, missing-docstring, too-many-arguments, too-many-locals
# pylint: disable=too-many-branches, too-many-statements, too-many-instance-attributes
import socket
import struct
import time
import logging
from threading import Thread
import argparse
import sys
from typing import Dict, List, Tuple, Union, Optional
from typing import Dict, List, Tuple, Union, Optional, Set
_RingMap = Dict[int, Tuple[int, int]]
_TreeMap = Dict[int, List[int]]
@ -28,6 +25,7 @@ class ExSocket:
self.sock = sock
def recvall(self, nbytes: int) -> bytes:
"""Receive number of bytes."""
res = []
nread = 0
while nread < nbytes:
@ -37,40 +35,47 @@ class ExSocket:
return b''.join(res)
def recvint(self) -> int:
"""Receive an integer of 32 bytes"""
return struct.unpack('@i', self.recvall(4))[0]
def sendint(self, n: int) -> None:
self.sock.sendall(struct.pack('@i', n))
def sendint(self, value: int) -> None:
"""Send an integer of 32 bytes"""
self.sock.sendall(struct.pack('@i', value))
def sendstr(self, s: str) -> None:
self.sendint(len(s))
self.sock.sendall(s.encode())
def sendstr(self, value: str) -> None:
"""Send a Python string"""
self.sendint(len(value))
self.sock.sendall(value.encode())
def recvstr(self) -> str:
"""Receive a Python string"""
slen = self.recvint()
return self.recvall(slen).decode()
# magic number used to verify existence of data
kMagic = 0xff99
MAGIC_NUM = 0xff99
def get_some_ip(host: str) -> str:
"""Get ip from host"""
return socket.getaddrinfo(host, None)[0][4][0]
def get_family(addr: str) -> int:
"""Get network family from address."""
return socket.getaddrinfo(addr, None)[0][0]
class WorkerEntry:
"""Hanlder to each worker."""
def __init__(self, sock: socket.socket, s_addr: Tuple[str, int]):
worker = ExSocket(sock)
self.sock = worker
self.host = get_some_ip(s_addr[0])
magic = worker.recvint()
assert magic == kMagic, f"invalid magic number={magic} from {self.host}"
worker.sendint(kMagic)
assert magic == MAGIC_NUM, f"invalid magic number={magic} from {self.host}"
worker.sendint(MAGIC_NUM)
self.rank = worker.recvint()
self.world_size = worker.recvint()
self.jobid = worker.recvstr()
@ -78,7 +83,17 @@ class WorkerEntry:
self.wait_accept = 0
self.port: Optional[int] = None
def print(self, use_logger: bool) -> None:
"""Execute the print command from worker."""
msg = self.sock.recvstr()
# On dask we use print to avoid setting global verbosity.
if use_logger:
logging.info(msg.strip())
else:
print(msg.strip(), flush=True)
def decide_rank(self, job_map: Dict[str, int]) -> int:
"""Get the rank of current entry."""
if self.rank >= 0:
return self.rank
if self.jobid != 'NULL' and self.jobid in job_map:
@ -93,6 +108,7 @@ class WorkerEntry:
parent_map: Dict[int, int],
ring_map: _RingMap,
) -> List[int]:
"""Assign the rank for current entry."""
self.rank = rank
nnset = set(tree_map[rank])
rprev, rnext = ring_map[rank]
@ -117,6 +133,12 @@ class WorkerEntry:
self.sock.sendint(rnext)
else:
self.sock.sendint(-1)
return self._get_remote(wait_conn, nnset)
def _get_remote(
self, wait_conn: Dict[int, "WorkerEntry"], nnset: Set[int]
) -> List[int]:
while True:
ngood = self.sock.recvint()
goodset = set([])
@ -158,10 +180,7 @@ class RabitTracker:
"""
def __init__(
self, hostIP: str,
n_workers: int,
port: int = 0,
use_logger: bool = False,
self, host_ip: str, n_workers: int, port: int = 0, use_logger: bool = False
) -> None:
"""A Python implementation of RABIT tracker.
@ -172,23 +191,23 @@ class RabitTracker:
function is used instead.
"""
sock = socket.socket(get_family(hostIP), socket.SOCK_STREAM)
sock.bind((hostIP, port))
sock = socket.socket(get_family(host_ip), socket.SOCK_STREAM)
sock.bind((host_ip, port))
self.port = sock.getsockname()[1]
sock.listen(256)
self.sock = sock
self.hostIP = hostIP
self.host_ip = host_ip
self.thread: Optional[Thread] = None
self.n_workers = n_workers
self._use_logger = use_logger
logging.info('start listen on %s:%d', hostIP, self.port)
logging.info("start listen on %s:%d", host_ip, self.port)
def __del__(self) -> None:
if hasattr(self, "sock"):
self.sock.close()
@staticmethod
def get_neighbor(rank: int, n_workers: int) -> List[int]:
def _get_neighbor(rank: int, n_workers: int) -> List[int]:
rank = rank + 1
ret = []
if rank > 1:
@ -204,29 +223,28 @@ class RabitTracker:
get environment variables for workers
can be passed in as args or envs
"""
return {'DMLC_TRACKER_URI': self.hostIP,
'DMLC_TRACKER_PORT': self.port}
return {'DMLC_TRACKER_URI': self.host_ip, 'DMLC_TRACKER_PORT': self.port}
def get_tree(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int]]:
def _get_tree(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int]]:
tree_map: _TreeMap = {}
parent_map: Dict[int, int] = {}
for r in range(n_workers):
tree_map[r] = self.get_neighbor(r, n_workers)
tree_map[r] = self._get_neighbor(r, n_workers)
parent_map[r] = (r + 1) // 2 - 1
return tree_map, parent_map
def find_share_ring(
self, tree_map: _TreeMap, parent_map: Dict[int, int], r: int
self, tree_map: _TreeMap, parent_map: Dict[int, int], rank: int
) -> List[int]:
"""
get a ring structure that tends to share nodes with the tree
return a list starting from r
return a list starting from rank
"""
nset = set(tree_map[r])
cset = nset - set([parent_map[r]])
nset = set(tree_map[rank])
cset = nset - set([parent_map[rank]])
if not cset:
return [r]
rlst = [r]
return [rank]
rlst = [rank]
cnt = 0
for v in cset:
vlst = self.find_share_ring(tree_map, parent_map, v)
@ -256,7 +274,7 @@ class RabitTracker:
get the link map, this is a bit hacky, call for better algorithm
to place similar nodes together
"""
tree_map, parent_map = self.get_tree(n_workers)
tree_map, parent_map = self._get_tree(n_workers)
ring_map = self.get_ring(tree_map, parent_map)
rmap = {0: 0}
k = 0
@ -279,6 +297,7 @@ class RabitTracker:
return tree_map_, parent_map_, ring_map_
def accept_workers(self, n_workers: int) -> None:
"""Wait for all workers to connect to the tracker."""
# set of nodes that finishes the job
shutdown: Dict[int, WorkerEntry] = {}
# set of nodes that is waiting for connections
@ -290,18 +309,11 @@ class RabitTracker:
# lazy initialize tree_map
tree_map = None
start_time = time.time()
while len(shutdown) != n_workers:
fd, s_addr = self.sock.accept()
s = WorkerEntry(fd, s_addr)
if s.cmd == 'print':
msg = s.sock.recvstr()
# On dask we use print to avoid setting global verbosity.
if self._use_logger:
logging.info(msg.strip())
else:
print(msg.strip(), flush=True)
s.print(self._use_logger)
continue
if s.cmd == 'shutdown':
assert s.rank >= 0 and s.rank not in shutdown
@ -347,13 +359,9 @@ class RabitTracker:
if s.wait_accept > 0:
wait_conn[rank] = s
logging.info('@tracker All nodes finishes job')
end_time = time.time()
logging.info(
'@tracker %s secs between node start and job finish',
str(end_time - start_time)
)
def start(self, n_workers: int) -> None:
"""Strat the tracker, it will wait for `n_workers` to connect."""
def run() -> None:
self.accept_workers(n_workers)
@ -361,36 +369,42 @@ class RabitTracker:
self.thread.start()
def join(self) -> None:
"""Wait for the tracker to finish."""
while self.thread is not None and self.thread.is_alive():
self.thread.join(100)
def alive(self) -> bool:
"""Wether the tracker thread is alive"""
return self.thread is not None and self.thread.is_alive()
def get_host_ip(hostIP: Optional[str] = None) -> str:
if hostIP is None or hostIP == 'auto':
hostIP = 'ip'
def get_host_ip(host_ip: Optional[str] = None) -> str:
"""Get the IP address of current host. If `host_ip` is not none then it will be
returned as it's
if hostIP == 'dns':
hostIP = socket.getfqdn()
elif hostIP == 'ip':
"""
if host_ip is None or host_ip == 'auto':
host_ip = 'ip'
if host_ip == 'dns':
host_ip = socket.getfqdn()
elif host_ip == 'ip':
from socket import gaierror
try:
hostIP = socket.gethostbyname(socket.getfqdn())
host_ip = socket.gethostbyname(socket.getfqdn())
except gaierror:
logging.debug(
'gethostbyname(socket.getfqdn()) failed... trying on hostname()'
)
hostIP = socket.gethostbyname(socket.gethostname())
if hostIP.startswith("127."):
host_ip = socket.gethostbyname(socket.gethostname())
if host_ip.startswith("127."):
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
# doesn't have to be reachable
s.connect(('10.255.255.255', 1))
hostIP = s.getsockname()[0]
host_ip = s.getsockname()[0]
assert hostIP is not None
return hostIP
assert host_ip is not None
return host_ip
def start_rabit_tracker(args: argparse.Namespace) -> None:
@ -402,7 +416,7 @@ def start_rabit_tracker(args: argparse.Namespace) -> None:
"""
envs = {"DMLC_NUM_WORKER": args.num_workers, "DMLC_NUM_SERVER": args.num_servers}
rabit = RabitTracker(
hostIP=get_host_ip(args.host_ip), n_workers=args.num_workers, use_logger=True
host_ip=get_host_ip(args.host_ip), n_workers=args.num_workers, use_logger=True
)
envs.update(rabit.worker_envs())
rabit.start(args.num_workers)

View File

@ -10,7 +10,7 @@ if sys.platform.startswith("win"):
def test_rabit_tracker():
tracker = RabitTracker(hostIP='127.0.0.1', n_workers=1)
tracker = RabitTracker(host_ip='127.0.0.1', n_workers=1)
tracker.start(1)
worker_env = tracker.worker_envs()
rabit_env = []