Cleanup some pylint errors. (#7667)
* Cleanup some pylint errors. * Cleanup pylint errors in rabit modules. * Make data iter an abstract class and cleanup private access. * Cleanup no-self-use for booster.
This commit is contained in:
parent
b76c5d54bf
commit
f08c5dcb06
@ -1,11 +1,9 @@
|
||||
# coding: utf-8
|
||||
# pylint: disable=too-many-arguments, too-many-branches, invalid-name
|
||||
# pylint: disable=too-many-lines, too-many-locals, no-self-use
|
||||
# pylint: disable=too-many-lines, too-many-locals
|
||||
"""Core XGBoost Library."""
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from typing import List, Optional, Any, Union, Dict, TypeVar
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
from typing import Callable, Tuple, cast, Sequence
|
||||
import ctypes
|
||||
import os
|
||||
@ -123,9 +121,8 @@ def _log_callback(msg: bytes) -> None:
|
||||
|
||||
def _get_log_callback_func() -> Callable:
|
||||
"""Wrap log_callback() method in ctypes callback type"""
|
||||
# pylint: disable=invalid-name
|
||||
CALLBACK = ctypes.CFUNCTYPE(None, ctypes.c_char_p)
|
||||
return CALLBACK(_log_callback)
|
||||
c_callback = ctypes.CFUNCTYPE(None, ctypes.c_char_p)
|
||||
return c_callback(_log_callback)
|
||||
|
||||
|
||||
def _load_lib() -> ctypes.CDLL:
|
||||
@ -311,7 +308,7 @@ def _prediction_output(shape, dims, predts, is_cuda):
|
||||
return arr_predict
|
||||
|
||||
|
||||
class DataIter: # pylint: disable=too-many-instance-attributes
|
||||
class DataIter(ABC): # pylint: disable=too-many-instance-attributes
|
||||
"""The interface for user defined data iterator.
|
||||
|
||||
Parameters
|
||||
@ -333,9 +330,10 @@ class DataIter: # pylint: disable=too-many-instance-attributes
|
||||
# Stage data in Python until reset or next is called to avoid data being freed.
|
||||
self._temporary_data: Optional[Tuple[Any, Any]] = None
|
||||
|
||||
def _get_callbacks(
|
||||
def get_callbacks(
|
||||
self, allow_host: bool, enable_categorical: bool
|
||||
) -> Tuple[Callable, Callable]:
|
||||
"""Get callback functions for iterating in C."""
|
||||
assert hasattr(self, "cache_prefix"), "__init__ is not called."
|
||||
self._reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(
|
||||
self._reset_wrapper
|
||||
@ -369,7 +367,8 @@ class DataIter: # pylint: disable=too-many-instance-attributes
|
||||
self._exception = e.with_traceback(tb)
|
||||
return dft_ret
|
||||
|
||||
def _reraise(self) -> None:
|
||||
def reraise(self) -> None:
|
||||
"""Reraise the exception thrown during iteration."""
|
||||
self._temporary_data = None
|
||||
if self._exception is not None:
|
||||
# pylint 2.7.0 believes `self._exception` can be None even with `assert
|
||||
@ -424,10 +423,12 @@ class DataIter: # pylint: disable=too-many-instance-attributes
|
||||
# pylint: disable=not-callable
|
||||
return self._handle_exception(lambda: self.next(data_handle), 0)
|
||||
|
||||
@abstractmethod
|
||||
def reset(self) -> None:
|
||||
"""Reset the data iterator. Prototype for user defined function."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def next(self, input_data: Callable) -> int:
|
||||
"""Set the next batch of data.
|
||||
|
||||
@ -642,8 +643,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
}
|
||||
args = from_pystr_to_cstr(json.dumps(args))
|
||||
handle = ctypes.c_void_p()
|
||||
# pylint: disable=protected-access
|
||||
reset_callback, next_callback = it._get_callbacks(
|
||||
reset_callback, next_callback = it.get_callbacks(
|
||||
True, enable_categorical
|
||||
)
|
||||
ret = _LIB.XGDMatrixCreateFromCallback(
|
||||
@ -654,8 +654,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
args,
|
||||
ctypes.byref(handle),
|
||||
)
|
||||
# pylint: disable=protected-access
|
||||
it._reraise()
|
||||
it.reraise()
|
||||
# delay check_call to throw intermediate exception first
|
||||
_check_call(ret)
|
||||
self.handle = handle
|
||||
@ -1225,8 +1224,7 @@ class DeviceQuantileDMatrix(DMatrix):
|
||||
it = SingleBatchInternalIter(data=data, **meta)
|
||||
|
||||
handle = ctypes.c_void_p()
|
||||
# pylint: disable=protected-access
|
||||
reset_callback, next_callback = it._get_callbacks(False, enable_categorical)
|
||||
reset_callback, next_callback = it.get_callbacks(False, enable_categorical)
|
||||
if it.cache_prefix is not None:
|
||||
raise ValueError(
|
||||
"DeviceQuantileDMatrix doesn't cache data, remove the cache_prefix "
|
||||
@ -1242,8 +1240,7 @@ class DeviceQuantileDMatrix(DMatrix):
|
||||
ctypes.c_int(self.max_bin),
|
||||
ctypes.byref(handle),
|
||||
)
|
||||
# pylint: disable=protected-access
|
||||
it._reraise()
|
||||
it.reraise()
|
||||
# delay check_call to throw intermediate exception first
|
||||
_check_call(ret)
|
||||
self.handle = handle
|
||||
@ -1281,6 +1278,21 @@ def _get_booster_layer_trees(model: "Booster") -> Tuple[int, int]:
|
||||
return num_parallel_tree, num_groups
|
||||
|
||||
|
||||
def _configure_metrics(params: Union[Dict, List]) -> Union[Dict, List]:
|
||||
if (
|
||||
isinstance(params, dict)
|
||||
and "eval_metric" in params
|
||||
and isinstance(params["eval_metric"], list)
|
||||
):
|
||||
params = dict((k, v) for k, v in params.items())
|
||||
eval_metrics = params["eval_metric"]
|
||||
params.pop("eval_metric", None)
|
||||
params = list(params.items())
|
||||
for eval_metric in eval_metrics:
|
||||
params += [("eval_metric", eval_metric)]
|
||||
return params
|
||||
|
||||
|
||||
class Booster:
|
||||
# pylint: disable=too-many-public-methods
|
||||
"""A Booster of XGBoost.
|
||||
@ -1339,7 +1351,7 @@ class Booster:
|
||||
raise TypeError('Unknown type:', model_file)
|
||||
|
||||
params = params or {}
|
||||
params = self._configure_metrics(params.copy())
|
||||
params = _configure_metrics(params.copy())
|
||||
params = self._configure_constraints(params)
|
||||
if isinstance(params, list):
|
||||
params.append(('validate_parameters', True))
|
||||
@ -1352,17 +1364,6 @@ class Booster:
|
||||
else:
|
||||
self.booster = 'gbtree'
|
||||
|
||||
def _configure_metrics(self, params: Union[Dict, List]) -> Union[Dict, List]:
|
||||
if isinstance(params, dict) and 'eval_metric' in params \
|
||||
and isinstance(params['eval_metric'], list):
|
||||
params = dict((k, v) for k, v in params.items())
|
||||
eval_metrics = params['eval_metric']
|
||||
params.pop("eval_metric", None)
|
||||
params = list(params.items())
|
||||
for eval_metric in eval_metrics:
|
||||
params += [('eval_metric', eval_metric)]
|
||||
return params
|
||||
|
||||
def _transform_monotone_constrains(self, value: Union[Dict[str, int], str]) -> str:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
@ -1395,7 +1396,6 @@ class Booster:
|
||||
)
|
||||
return s + "]"
|
||||
except KeyError as e:
|
||||
# pylint: disable=raise-missing-from
|
||||
raise ValueError(
|
||||
"Constrained features are not a subset of training data feature names"
|
||||
) from e
|
||||
|
||||
@ -171,7 +171,7 @@ def _try_start_tracker(
|
||||
host_ip = addrs[0][0]
|
||||
port = addrs[0][1]
|
||||
rabit_context = RabitTracker(
|
||||
hostIP=get_host_ip(host_ip),
|
||||
host_ip=get_host_ip(host_ip),
|
||||
n_workers=n_workers,
|
||||
port=port,
|
||||
use_logger=False,
|
||||
@ -179,7 +179,7 @@ def _try_start_tracker(
|
||||
else:
|
||||
assert isinstance(addrs[0], str) or addrs[0] is None
|
||||
rabit_context = RabitTracker(
|
||||
hostIP=get_host_ip(addrs[0]), n_workers=n_workers, use_logger=False
|
||||
host_ip=get_host_ip(addrs[0]), n_workers=n_workers, use_logger=False
|
||||
)
|
||||
env.update(rabit_context.worker_envs())
|
||||
rabit_context.start(n_workers)
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# coding: utf-8
|
||||
# pylint: disable= invalid-name
|
||||
"""Distributed XGBoost Rabit related API."""
|
||||
import ctypes
|
||||
from enum import IntEnum, unique
|
||||
import pickle
|
||||
from typing import Any, TypeVar, Callable, Optional, cast, List, Union
|
||||
|
||||
@ -98,7 +97,7 @@ def get_processor_name() -> bytes:
|
||||
return buf.value
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
T = TypeVar("T") # pylint:disable=invalid-name
|
||||
|
||||
|
||||
def broadcast(data: T, root: int) -> T:
|
||||
@ -152,7 +151,8 @@ DTYPE_ENUM__ = {
|
||||
}
|
||||
|
||||
|
||||
class Op: # pylint: disable=too-few-public-methods
|
||||
@unique
|
||||
class Op(IntEnum):
|
||||
'''Supported operations for rabit.'''
|
||||
MAX = 0
|
||||
MIN = 1
|
||||
@ -160,18 +160,18 @@ class Op: # pylint: disable=too-few-public-methods
|
||||
OR = 3
|
||||
|
||||
|
||||
def allreduce(
|
||||
data: np.ndarray, op: int, prepare_fun: Optional[Callable[[np.ndarray], None]] = None
|
||||
def allreduce( # pylint:disable=invalid-name
|
||||
data: np.ndarray, op: Op, prepare_fun: Optional[Callable[[np.ndarray], None]] = None
|
||||
) -> np.ndarray:
|
||||
"""Perform allreduce, return the result.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data: numpy array
|
||||
data :
|
||||
Input data.
|
||||
op: int
|
||||
op :
|
||||
Reduction operators, can be MIN, MAX, SUM, BITOR
|
||||
prepare_fun: function
|
||||
prepare_fun :
|
||||
Lazy preprocessing function, if it is not None, prepare_fun(data)
|
||||
will be called by the function before performing allreduce, to initialize the data
|
||||
If the result of Allreduce can be recovered directly,
|
||||
@ -179,7 +179,7 @@ def allreduce(
|
||||
|
||||
Returns
|
||||
-------
|
||||
result : array_like
|
||||
result :
|
||||
The result of allreduce, have same shape as data
|
||||
|
||||
Notes
|
||||
@ -196,7 +196,7 @@ def allreduce(
|
||||
if prepare_fun is None:
|
||||
_check_call(_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
|
||||
buf.size, DTYPE_ENUM__[buf.dtype],
|
||||
op, None, None))
|
||||
int(op), None, None))
|
||||
else:
|
||||
func_ptr = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
|
||||
|
||||
|
||||
@ -1,19 +1,16 @@
|
||||
# pylint: disable=too-many-instance-attributes, too-many-arguments, too-many-branches
|
||||
"""
|
||||
This script is a variant of dmlc-core/dmlc_tracker/tracker.py,
|
||||
which is a specialized version for xgboost tasks.
|
||||
"""
|
||||
|
||||
# pylint: disable=invalid-name, missing-docstring, too-many-arguments, too-many-locals
|
||||
# pylint: disable=too-many-branches, too-many-statements, too-many-instance-attributes
|
||||
import socket
|
||||
import struct
|
||||
import time
|
||||
import logging
|
||||
from threading import Thread
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
from typing import Dict, List, Tuple, Union, Optional
|
||||
from typing import Dict, List, Tuple, Union, Optional, Set
|
||||
|
||||
_RingMap = Dict[int, Tuple[int, int]]
|
||||
_TreeMap = Dict[int, List[int]]
|
||||
@ -28,6 +25,7 @@ class ExSocket:
|
||||
self.sock = sock
|
||||
|
||||
def recvall(self, nbytes: int) -> bytes:
|
||||
"""Receive number of bytes."""
|
||||
res = []
|
||||
nread = 0
|
||||
while nread < nbytes:
|
||||
@ -37,40 +35,47 @@ class ExSocket:
|
||||
return b''.join(res)
|
||||
|
||||
def recvint(self) -> int:
|
||||
"""Receive a 32-bit (4-byte) integer"""
|
||||
return struct.unpack('@i', self.recvall(4))[0]
|
||||
|
||||
def sendint(self, n: int) -> None:
|
||||
self.sock.sendall(struct.pack('@i', n))
|
||||
def sendint(self, value: int) -> None:
|
||||
"""Send a 32-bit (4-byte) integer"""
|
||||
self.sock.sendall(struct.pack('@i', value))
|
||||
|
||||
def sendstr(self, s: str) -> None:
|
||||
self.sendint(len(s))
|
||||
self.sock.sendall(s.encode())
|
||||
def sendstr(self, value: str) -> None:
|
||||
"""Send a Python string"""
|
||||
self.sendint(len(value))
|
||||
self.sock.sendall(value.encode())
|
||||
|
||||
def recvstr(self) -> str:
|
||||
"""Receive a Python string"""
|
||||
slen = self.recvint()
|
||||
return self.recvall(slen).decode()
|
||||
|
||||
|
||||
# magic number used to verify existence of data
|
||||
kMagic = 0xff99
|
||||
MAGIC_NUM = 0xff99
|
||||
|
||||
|
||||
def get_some_ip(host: str) -> str:
|
||||
"""Get ip from host"""
|
||||
return socket.getaddrinfo(host, None)[0][4][0]
|
||||
|
||||
|
||||
def get_family(addr: str) -> int:
|
||||
"""Get network family from address."""
|
||||
return socket.getaddrinfo(addr, None)[0][0]
|
||||
|
||||
|
||||
class WorkerEntry:
|
||||
"""Handler for each worker."""
|
||||
def __init__(self, sock: socket.socket, s_addr: Tuple[str, int]):
|
||||
worker = ExSocket(sock)
|
||||
self.sock = worker
|
||||
self.host = get_some_ip(s_addr[0])
|
||||
magic = worker.recvint()
|
||||
assert magic == kMagic, f"invalid magic number={magic} from {self.host}"
|
||||
worker.sendint(kMagic)
|
||||
assert magic == MAGIC_NUM, f"invalid magic number={magic} from {self.host}"
|
||||
worker.sendint(MAGIC_NUM)
|
||||
self.rank = worker.recvint()
|
||||
self.world_size = worker.recvint()
|
||||
self.jobid = worker.recvstr()
|
||||
@ -78,7 +83,17 @@ class WorkerEntry:
|
||||
self.wait_accept = 0
|
||||
self.port: Optional[int] = None
|
||||
|
||||
def print(self, use_logger: bool) -> None:
|
||||
"""Execute the print command from worker."""
|
||||
msg = self.sock.recvstr()
|
||||
# On dask we use print to avoid setting global verbosity.
|
||||
if use_logger:
|
||||
logging.info(msg.strip())
|
||||
else:
|
||||
print(msg.strip(), flush=True)
|
||||
|
||||
def decide_rank(self, job_map: Dict[str, int]) -> int:
|
||||
"""Get the rank of current entry."""
|
||||
if self.rank >= 0:
|
||||
return self.rank
|
||||
if self.jobid != 'NULL' and self.jobid in job_map:
|
||||
@ -93,6 +108,7 @@ class WorkerEntry:
|
||||
parent_map: Dict[int, int],
|
||||
ring_map: _RingMap,
|
||||
) -> List[int]:
|
||||
"""Assign the rank for current entry."""
|
||||
self.rank = rank
|
||||
nnset = set(tree_map[rank])
|
||||
rprev, rnext = ring_map[rank]
|
||||
@ -117,6 +133,12 @@ class WorkerEntry:
|
||||
self.sock.sendint(rnext)
|
||||
else:
|
||||
self.sock.sendint(-1)
|
||||
|
||||
return self._get_remote(wait_conn, nnset)
|
||||
|
||||
def _get_remote(
|
||||
self, wait_conn: Dict[int, "WorkerEntry"], nnset: Set[int]
|
||||
) -> List[int]:
|
||||
while True:
|
||||
ngood = self.sock.recvint()
|
||||
goodset = set([])
|
||||
@ -158,10 +180,7 @@ class RabitTracker:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, hostIP: str,
|
||||
n_workers: int,
|
||||
port: int = 0,
|
||||
use_logger: bool = False,
|
||||
self, host_ip: str, n_workers: int, port: int = 0, use_logger: bool = False
|
||||
) -> None:
|
||||
"""A Python implementation of RABIT tracker.
|
||||
|
||||
@ -172,23 +191,23 @@ class RabitTracker:
|
||||
function is used instead.
|
||||
|
||||
"""
|
||||
sock = socket.socket(get_family(hostIP), socket.SOCK_STREAM)
|
||||
sock.bind((hostIP, port))
|
||||
sock = socket.socket(get_family(host_ip), socket.SOCK_STREAM)
|
||||
sock.bind((host_ip, port))
|
||||
self.port = sock.getsockname()[1]
|
||||
sock.listen(256)
|
||||
self.sock = sock
|
||||
self.hostIP = hostIP
|
||||
self.host_ip = host_ip
|
||||
self.thread: Optional[Thread] = None
|
||||
self.n_workers = n_workers
|
||||
self._use_logger = use_logger
|
||||
logging.info('start listen on %s:%d', hostIP, self.port)
|
||||
logging.info("start listen on %s:%d", host_ip, self.port)
|
||||
|
||||
def __del__(self) -> None:
|
||||
if hasattr(self, "sock"):
|
||||
self.sock.close()
|
||||
|
||||
@staticmethod
|
||||
def get_neighbor(rank: int, n_workers: int) -> List[int]:
|
||||
def _get_neighbor(rank: int, n_workers: int) -> List[int]:
|
||||
rank = rank + 1
|
||||
ret = []
|
||||
if rank > 1:
|
||||
@ -204,29 +223,28 @@ class RabitTracker:
|
||||
get environment variables for workers
|
||||
can be passed in as args or envs
|
||||
"""
|
||||
return {'DMLC_TRACKER_URI': self.hostIP,
|
||||
'DMLC_TRACKER_PORT': self.port}
|
||||
return {'DMLC_TRACKER_URI': self.host_ip, 'DMLC_TRACKER_PORT': self.port}
|
||||
|
||||
def get_tree(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int]]:
|
||||
def _get_tree(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int]]:
|
||||
tree_map: _TreeMap = {}
|
||||
parent_map: Dict[int, int] = {}
|
||||
for r in range(n_workers):
|
||||
tree_map[r] = self.get_neighbor(r, n_workers)
|
||||
tree_map[r] = self._get_neighbor(r, n_workers)
|
||||
parent_map[r] = (r + 1) // 2 - 1
|
||||
return tree_map, parent_map
|
||||
|
||||
def find_share_ring(
|
||||
self, tree_map: _TreeMap, parent_map: Dict[int, int], r: int
|
||||
self, tree_map: _TreeMap, parent_map: Dict[int, int], rank: int
|
||||
) -> List[int]:
|
||||
"""
|
||||
get a ring structure that tends to share nodes with the tree
|
||||
return a list starting from r
|
||||
return a list starting from rank
|
||||
"""
|
||||
nset = set(tree_map[r])
|
||||
cset = nset - set([parent_map[r]])
|
||||
nset = set(tree_map[rank])
|
||||
cset = nset - set([parent_map[rank]])
|
||||
if not cset:
|
||||
return [r]
|
||||
rlst = [r]
|
||||
return [rank]
|
||||
rlst = [rank]
|
||||
cnt = 0
|
||||
for v in cset:
|
||||
vlst = self.find_share_ring(tree_map, parent_map, v)
|
||||
@ -256,7 +274,7 @@ class RabitTracker:
|
||||
get the link map, this is a bit hacky, call for better algorithm
|
||||
to place similar nodes together
|
||||
"""
|
||||
tree_map, parent_map = self.get_tree(n_workers)
|
||||
tree_map, parent_map = self._get_tree(n_workers)
|
||||
ring_map = self.get_ring(tree_map, parent_map)
|
||||
rmap = {0: 0}
|
||||
k = 0
|
||||
@ -279,6 +297,7 @@ class RabitTracker:
|
||||
return tree_map_, parent_map_, ring_map_
|
||||
|
||||
def accept_workers(self, n_workers: int) -> None:
|
||||
"""Wait for all workers to connect to the tracker."""
|
||||
# set of nodes that have finished the job
|
||||
shutdown: Dict[int, WorkerEntry] = {}
|
||||
# set of nodes that are waiting for connections
|
||||
@ -290,18 +309,11 @@ class RabitTracker:
|
||||
# lazy initialize tree_map
|
||||
tree_map = None
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
while len(shutdown) != n_workers:
|
||||
fd, s_addr = self.sock.accept()
|
||||
s = WorkerEntry(fd, s_addr)
|
||||
if s.cmd == 'print':
|
||||
msg = s.sock.recvstr()
|
||||
# On dask we use print to avoid setting global verbosity.
|
||||
if self._use_logger:
|
||||
logging.info(msg.strip())
|
||||
else:
|
||||
print(msg.strip(), flush=True)
|
||||
s.print(self._use_logger)
|
||||
continue
|
||||
if s.cmd == 'shutdown':
|
||||
assert s.rank >= 0 and s.rank not in shutdown
|
||||
@ -347,13 +359,9 @@ class RabitTracker:
|
||||
if s.wait_accept > 0:
|
||||
wait_conn[rank] = s
|
||||
logging.info('@tracker All nodes finishes job')
|
||||
end_time = time.time()
|
||||
logging.info(
|
||||
'@tracker %s secs between node start and job finish',
|
||||
str(end_time - start_time)
|
||||
)
|
||||
|
||||
def start(self, n_workers: int) -> None:
|
||||
"""Start the tracker; it will wait for `n_workers` to connect."""
|
||||
def run() -> None:
|
||||
self.accept_workers(n_workers)
|
||||
|
||||
@ -361,36 +369,42 @@ class RabitTracker:
|
||||
self.thread.start()
|
||||
|
||||
def join(self) -> None:
|
||||
"""Wait for the tracker to finish."""
|
||||
while self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(100)
|
||||
|
||||
def alive(self) -> bool:
|
||||
"""Whether the tracker thread is alive"""
|
||||
return self.thread is not None and self.thread.is_alive()
|
||||
|
||||
|
||||
def get_host_ip(hostIP: Optional[str] = None) -> str:
|
||||
if hostIP is None or hostIP == 'auto':
|
||||
hostIP = 'ip'
|
||||
def get_host_ip(host_ip: Optional[str] = None) -> str:
|
||||
"""Get the IP address of current host. If `host_ip` is not none then it will be
|
||||
returned as is, unless it is one of the special values 'auto', 'ip' or 'dns'.
|
||||
|
||||
if hostIP == 'dns':
|
||||
hostIP = socket.getfqdn()
|
||||
elif hostIP == 'ip':
|
||||
"""
|
||||
if host_ip is None or host_ip == 'auto':
|
||||
host_ip = 'ip'
|
||||
|
||||
if host_ip == 'dns':
|
||||
host_ip = socket.getfqdn()
|
||||
elif host_ip == 'ip':
|
||||
from socket import gaierror
|
||||
try:
|
||||
hostIP = socket.gethostbyname(socket.getfqdn())
|
||||
host_ip = socket.gethostbyname(socket.getfqdn())
|
||||
except gaierror:
|
||||
logging.debug(
|
||||
'gethostbyname(socket.getfqdn()) failed... trying on hostname()'
|
||||
)
|
||||
hostIP = socket.gethostbyname(socket.gethostname())
|
||||
if hostIP.startswith("127."):
|
||||
host_ip = socket.gethostbyname(socket.gethostname())
|
||||
if host_ip.startswith("127."):
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
# doesn't have to be reachable
|
||||
s.connect(('10.255.255.255', 1))
|
||||
hostIP = s.getsockname()[0]
|
||||
host_ip = s.getsockname()[0]
|
||||
|
||||
assert hostIP is not None
|
||||
return hostIP
|
||||
assert host_ip is not None
|
||||
return host_ip
|
||||
|
||||
|
||||
def start_rabit_tracker(args: argparse.Namespace) -> None:
|
||||
@ -402,7 +416,7 @@ def start_rabit_tracker(args: argparse.Namespace) -> None:
|
||||
"""
|
||||
envs = {"DMLC_NUM_WORKER": args.num_workers, "DMLC_NUM_SERVER": args.num_servers}
|
||||
rabit = RabitTracker(
|
||||
hostIP=get_host_ip(args.host_ip), n_workers=args.num_workers, use_logger=True
|
||||
host_ip=get_host_ip(args.host_ip), n_workers=args.num_workers, use_logger=True
|
||||
)
|
||||
envs.update(rabit.worker_envs())
|
||||
rabit.start(args.num_workers)
|
||||
|
||||
@ -10,7 +10,7 @@ if sys.platform.startswith("win"):
|
||||
|
||||
|
||||
def test_rabit_tracker():
|
||||
tracker = RabitTracker(hostIP='127.0.0.1', n_workers=1)
|
||||
tracker = RabitTracker(host_ip='127.0.0.1', n_workers=1)
|
||||
tracker.start(1)
|
||||
worker_env = tracker.worker_envs()
|
||||
rabit_env = []
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user