Fix pylint (#7241)

parent 38a23f66a8
commit b18f5f61b0
@@ -229,9 +229,9 @@ class InstallLib(install_lib.install_lib):
             os.mkdir(lib_dir)
         dst = os.path.join(self.install_dir, 'xgboost', 'lib', lib_name())
 
-        global BUILD_TEMP_DIR  # pylint: disable=global-statement
         libxgboost_path = lib_name()
 
+        assert BUILD_TEMP_DIR is not None
         dft_lib_dir = os.path.join(CURRENT_DIR, os.path.pardir, 'lib')
         build_dir = os.path.join(BUILD_TEMP_DIR, 'xgboost', 'lib')
 
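Note: newer pylint versions flag a `global` statement that only reads a module-level name (the statement is needed only for assignment), which is presumably why the hunk above drops `global BUILD_TEMP_DIR` and documents the invariant with an assert instead. A minimal sketch of the distinction, with hypothetical names:

    import os

    BUILD_TEMP_DIR = None  # set elsewhere during the build


    def configure(temp_dir):
        global BUILD_TEMP_DIR  # assignment: `global` is required here
        BUILD_TEMP_DIR = temp_dir


    def lib_dir():
        # read-only access: no `global` needed, and the assert both
        # documents and checks that configure() ran first
        assert BUILD_TEMP_DIR is not None
        return os.path.join(BUILD_TEMP_DIR, 'xgboost', 'lib')


    configure('/tmp/build')
    print(lib_dir())
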
@@ -28,11 +28,11 @@ def _get_callback_context(env):
 def _fmt_metric(value, show_stdv=True):
     """format metric string"""
     if len(value) == 2:
-        return '{0}:{1:.5f}'.format(value[0], value[1])
+        return f"{value[0]}:{value[1]:.5f}"
     if len(value) == 3:
         if show_stdv:
-            return '{0}:{1:.5f}+{2:.5f}'.format(value[0], value[1], value[2])
-        return '{0}:{1:.5f}'.format(value[0], value[1])
+            return f"{value[0]}:{value[1]:.5f}+{value[2]:.5f}"
+        return f"{value[0]}:{value[1]:.5f}"
     raise ValueError("wrong metric value", value)
 
 
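Note: most of this commit converts `%`-interpolation and `str.format` calls to f-strings, which is what pylint's `consider-using-f-string` (C0209) check asks for. The three styles are equivalent at runtime; a small sketch:

    value = ("train-rmse", 0.123456)

    percent_style = "%s:%.5f" % (value[0], value[1])
    format_style = "{0}:{1:.5f}".format(value[0], value[1])
    fstring_style = f"{value[0]}:{value[1]:.5f}"

    assert percent_style == format_style == fstring_style
    print(fstring_style)  # train-rmse:0.12346
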
@@ -62,7 +62,7 @@ def print_evaluation(period=1, show_stdv=True):
         i = env.iteration
         if i % period == 0 or i + 1 == env.begin_iteration or i + 1 == env.end_iteration:
             msg = '\t'.join([_fmt_metric(x, show_stdv) for x in env.evaluation_result_list])
-            rabit.tracker_print('[%d]\t%s\n' % (i, msg))
+            rabit.tracker_print(f"[{i}]\t{msg}\n")
     return callback
 
 
@@ -217,9 +217,11 @@ def early_stop(stopping_rounds, maximize=False, verbose=True):
             state['best_score'] = float('-inf')
         else:
             state['best_score'] = float('inf')
+        # pylint: disable=consider-using-f-string
         msg = '[%d]\t%s' % (
             env.iteration,
-            '\t'.join([_fmt_metric(x) for x in env.evaluation_result_list]))
+            '\t'.join([_fmt_metric(x) for x in env.evaluation_result_list])
+        )
         state['best_msg'] = msg
 
         if bst is not None:
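Note: where the old `%`-formatting is kept (the message spans several lines), the hunk above suppresses the check locally instead of rewriting. A standalone `# pylint: disable=...` comment applies from that line to the end of the enclosing block, so placing it just before the statement keeps the rest of the module checked. A sketch of the same pattern:

    def report(iteration, results):
        # pylint: disable=consider-using-f-string
        return '[%d]\t%s' % (
            iteration,
            '\t'.join('%s:%.5f' % (name, score) for name, score in results),
        )


    print(report(7, [("train-rmse", 0.1), ("eval-rmse", 0.2)]))
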
@@ -243,6 +245,7 @@ def early_stop(stopping_rounds, maximize=False, verbose=True):
         maximize_score = state['maximize_score']
         if (maximize_score and score > best_score) or \
                 (not maximize_score and score < best_score):
+            # pylint: disable=consider-using-f-string
             msg = '[%d]\t%s' % (
                 env.iteration,
                 '\t'.join([_fmt_metric(x) for x in env.evaluation_result_list]))
@@ -363,7 +366,7 @@ class CallbackContainer:
             ' will invoke monitor automatically.'
         assert callable(metric), msg
         self.metric = metric
-        self.history: CallbackContainer.EvalsLog = collections.OrderedDict()
+        self.history: TrainingCallback.EvalsLog = collections.OrderedDict()
         self.is_cv = is_cv
 
         if self.is_cv:
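Note: the annotation changes from `CallbackContainer.EvalsLog` to `TrainingCallback.EvalsLog`; the alias is presumably defined on `TrainingCallback`, so the new spelling names it where it lives. A hedged sketch of the pattern (the alias shape below is an assumption, not xgboost's exact definition):

    import collections
    from typing import Dict, List


    class TrainingCallback:
        # {"validation_0": {"rmse": [0.3, 0.2, ...]}}  (assumed shape)
        EvalsLog = Dict[str, Dict[str, List[float]]]


    class CallbackContainer:
        def __init__(self):
            self.history: TrainingCallback.EvalsLog = collections.OrderedDict()


    c = CallbackContainer()
    c.history["validation_0"] = {"rmse": [0.3, 0.2]}
    print(c.history)
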
@@ -519,7 +522,7 @@ class EarlyStopping(TrainingCallback):
         self.rounds = rounds
         self.save_best = save_best
         self.maximize = maximize
-        self.stopping_history: CallbackContainer.EvalsLog = {}
+        self.stopping_history: TrainingCallback.EvalsLog = {}
         self._min_delta = min_delta
         if self._min_delta < 0:
             raise ValueError("min_delta must be greater or equal to 0.")
@@ -589,7 +592,7 @@ class EarlyStopping(TrainingCallback):
         return False
 
     def after_iteration(self, model, epoch: int,
-                        evals_log: CallbackContainer.EvalsLog) -> bool:
+                        evals_log: TrainingCallback.EvalsLog) -> bool:
         epoch += self.starting_round  # training continuation
         msg = 'Must have at least 1 validation dataset for early stopping.'
         assert len(evals_log.keys()) >= 1, msg
@@ -653,15 +656,17 @@ class EvaluationMonitor(TrainingCallback):
         self._latest: Optional[str] = None
         super().__init__()
 
-    def _fmt_metric(self, data, metric, score, std) -> str:
+    def _fmt_metric(
+        self, data: str, metric: str, score: float, std: Optional[float]
+    ) -> str:
         if std is not None and self.show_stdv:
-            msg = '\t{0}:{1:.5f}+{2:.5f}'.format(data + '-' + metric, score, std)
+            msg = f"\t{data + '-' + metric}:{score:.5f}+{std:.5f}"
         else:
-            msg = '\t{0}:{1:.5f}'.format(data + '-' + metric, score)
+            msg = f"\t{data + '-' + metric}:{score:.5f}"
         return msg
 
     def after_iteration(self, model, epoch: int,
-                        evals_log: CallbackContainer.EvalsLog) -> bool:
+                        evals_log: TrainingCallback.EvalsLog) -> bool:
         if not evals_log:
             return False
 
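Note: besides the f-string conversion, the hunk above annotates `_fmt_metric`'s parameters; `std` becomes `Optional[float]`, presumably because a standard deviation is only present for cross-validation results (an inference from the `show_stdv` branch). A standalone sketch of the annotated helper:

    from typing import Optional


    def fmt_metric(data: str, metric: str, score: float,
                   std: Optional[float]) -> str:
        if std is not None:
            return f"\t{data + '-' + metric}:{score:.5f}+{std:.5f}"
        return f"\t{data + '-' + metric}:{score:.5f}"


    print(fmt_metric("eval", "rmse", 0.12345, 0.001))
    print(fmt_metric("eval", "rmse", 0.12345, None))
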
@@ -723,7 +728,7 @@ class TrainingCheckPoint(TrainingCallback):
         super().__init__()
 
     def after_iteration(self, model, epoch: int,
-                        evals_log: CallbackContainer.EvalsLog) -> bool:
+                        evals_log: TrainingCallback.EvalsLog) -> bool:
         if self._epoch == self._iterations:
             path = os.path.join(self._path, self._name + '_' + str(epoch) +
                                 ('.pkl' if self._as_pickle else '.json'))
@@ -140,9 +140,9 @@ def _expect(expectations, got):
     return msg
 
 
-def _log_callback(msg):
+def _log_callback(msg: bytes) -> None:
     """Redirect logs from native library into Python console"""
-    print("{0:s}".format(py_str(msg)))
+    print(py_str(msg))
 
 
 def _get_log_callback_func():
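Note: `msg: bytes` matches what a ctypes callback receives from a `const char *` argument, and since `py_str` already returns a decoded `str`, the old `"{0:s}".format(...)` wrapper added nothing. A self-contained sketch of the callback round trip (the `CFUNCTYPE` prototype here is an assumption about how xgboost registers it):

    import ctypes


    def py_str(data: bytes) -> str:
        # decode bytes coming back from the C library
        return data.decode("utf-8")


    def _log_callback(msg: bytes) -> None:
        """Redirect logs from native library into Python console"""
        print(py_str(msg))


    # assumed prototype: void (*)(const char *)
    CALLBACK = ctypes.CFUNCTYPE(None, ctypes.c_char_p)
    cb = CALLBACK(_log_callback)
    cb(b"[12:00:00] tree pruning end")
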
@@ -179,14 +179,19 @@ def _load_lib():
     if not lib_success:
         libname = os.path.basename(lib_paths[0])
         raise XGBoostError(
-            'XGBoost Library ({}) could not be loaded.\n'.format(libname) +
-            'Likely causes:\n' +
-            ' * OpenMP runtime is not installed ' +
-            '(vcomp140.dll or libgomp-1.dll for Windows, libomp.dylib for Mac OSX, ' +
-            'libgomp.so for Linux and other UNIX-like OSes). Mac OSX users: Run ' +
-            '`brew install libomp` to install OpenMP runtime.\n' +
-            ' * You are running 32-bit Python on a 64-bit OS\n' +
-            'Error message(s): {}\n'.format(os_error_list))
+            f"""
+XGBoost Library ({libname}) could not be loaded.
+Likely causes:
+  * OpenMP runtime is not installed
+    - vcomp140.dll or libgomp-1.dll for Windows
+    - libomp.dylib for Mac OSX
+    - libgomp.so for Linux and other UNIX-like OSes
+    Mac OSX users: Run `brew install libomp` to install OpenMP runtime.
+
+  * You are running 32-bit Python on a 64-bit OS
+
+Error message(s): {os_error_list}
+""")
     lib.XGBGetLastError.restype = ctypes.c_char_p
     lib.callback = _get_log_callback_func()
     if lib.XGBRegisterLogCallback(lib.callback) != 0:
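Note: the hunk above folds eight `+`-concatenated fragments into one triple-quoted f-string, so the error text can be read (and edited) in the shape it will be printed. A reduced sketch with illustrative values:

    libname = "libxgboost.so"
    os_error_list = ["libgomp.so.1: cannot open shared object file"]

    message = f"""
    XGBoost Library ({libname}) could not be loaded.
    Likely causes:
      * OpenMP runtime is not installed
      * You are running 32-bit Python on a 64-bit OS
    Error message(s): {os_error_list}
    """
    print(message)
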
@@ -246,7 +251,7 @@ def ctypes2numpy(cptr, length, dtype):
     """Convert a ctypes pointer array to a numpy array."""
     ctype = _numpy2ctypes_type(dtype)
     if not isinstance(cptr, ctypes.POINTER(ctype)):
-        raise RuntimeError("expected {} pointer".format(ctype))
+        raise RuntimeError(f"expected {ctype} pointer")
     res = np.zeros(length, dtype=dtype)
     if not ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0]):
         raise RuntimeError("memmove failed")
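Note: for readers unfamiliar with `ctypes2numpy`, the function copies `length` elements out of a raw C pointer into a fresh numpy buffer via `memmove`. A runnable sketch of the same steps:

    import ctypes
    import numpy as np

    length = 4
    src = (ctypes.c_float * length)(1.0, 2.0, 3.0, 4.0)
    cptr = ctypes.cast(src, ctypes.POINTER(ctypes.c_float))

    if not isinstance(cptr, ctypes.POINTER(ctypes.c_float)):
        raise RuntimeError(f"expected {ctypes.c_float} pointer")
    res = np.zeros(length, dtype=np.float32)
    # copy length * itemsize raw bytes into the numpy buffer
    if not ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0]):
        raise RuntimeError("memmove failed")
    print(res)  # [1. 2. 3. 4.]
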
@@ -262,7 +267,7 @@ def ctypes2cupy(cptr, length, dtype):
 
     CUPY_TO_CTYPES_MAPPING = {cupy.float32: ctypes.c_float, cupy.uint32: ctypes.c_uint}
     if dtype not in CUPY_TO_CTYPES_MAPPING.keys():
-        raise RuntimeError("Supported types: {}".format(CUPY_TO_CTYPES_MAPPING.keys()))
+        raise RuntimeError(f"Supported types: {CUPY_TO_CTYPES_MAPPING.keys()}")
     addr = ctypes.cast(cptr, ctypes.c_void_p).value
     # pylint: disable=c-extension-no-member,no-member
     device = cupy.cuda.runtime.pointerGetAttributes(addr).device
@@ -486,13 +491,16 @@ def _deprecate_positional_args(f):
         if extra_args > 0:
             # ignore first 'self' argument for instance methods
             args_msg = [
-                '{}'.format(name) for name, _ in zip(
-                    kwonly_args[:extra_args], args[-extra_args:])
+                f"{name}" for name, _ in zip(
+                    kwonly_args[:extra_args], args[-extra_args:]
+                )
             ]
+            # pylint: disable=consider-using-f-string
             warnings.warn(
                 "Pass `{}` as keyword args. Passing these as positional "
                 "arguments will be considered as error in future releases.".
-                format(", ".join(args_msg)), FutureWarning)
+                format(", ".join(args_msg)), FutureWarning
+            )
         for k, arg in zip(sig.parameters, args):
             kwargs[k] = arg
         return f(**kwargs)
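Note: `_deprecate_positional_args` follows the scikit-learn decorator pattern: inspect the wrapped signature, detect positional arguments that spill into keyword-only parameters, warn, then forward everything as keywords. A simplified, hedged reimplementation (not xgboost's exact code):

    import inspect
    import warnings
    from functools import wraps


    def deprecate_positional_args(f):
        sig = inspect.signature(f)
        kwonly_args = [n for n, p in sig.parameters.items()
                       if p.kind == inspect.Parameter.KEYWORD_ONLY]
        n_positional = sum(p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
                           for p in sig.parameters.values())

        @wraps(f)
        def inner(*args, **kwargs):
            extra_args = len(args) - n_positional
            if extra_args > 0:
                names = ", ".join(kwonly_args[:extra_args])
                warnings.warn(f"Pass `{names}` as keyword args.",
                              FutureWarning)
            # remap all positional arguments onto parameter names
            for k, arg in zip(sig.parameters, args):
                kwargs[k] = arg
            return f(**kwargs)

        return inner


    @deprecate_positional_args
    def train(data, *, num_boost_round=10):
        return data, num_boost_round


    print(train([1, 2], 20))  # warns, then returns ([1, 2], 20)
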
@@ -1292,7 +1300,7 @@ class Booster(object):
         """
         for d in cache:
             if not isinstance(d, DMatrix):
-                raise TypeError('invalid cache item: {}'.format(type(d).__name__), cache)
+                raise TypeError(f'invalid cache item: {type(d).__name__}', cache)
 
         dmats = c_array(ctypes.c_void_p, [d.handle for d in cache])
         self.handle = ctypes.c_void_p()
@@ -1665,8 +1673,7 @@ class Booster(object):
 
         """
         if not isinstance(dtrain, DMatrix):
-            raise TypeError('invalid training matrix: {}'.format(
-                type(dtrain).__name__))
+            raise TypeError(f"invalid training matrix: {type(dtrain).__name__}")
         self._validate_features(dtrain)
 
         if fobj is None:
@@ -1695,10 +1702,10 @@ class Booster(object):
         """
         if len(grad) != len(hess):
             raise ValueError(
-                'grad / hess length mismatch: {} / {}'.format(len(grad), len(hess))
+                f"grad / hess length mismatch: {len(grad)} / {len(hess)}"
             )
         if not isinstance(dtrain, DMatrix):
-            raise TypeError('invalid training matrix: {}'.format(type(dtrain).__name__))
+            raise TypeError(f"invalid training matrix: {type(dtrain).__name__}")
         self._validate_features(dtrain)
 
         _check_call(_LIB.XGBoosterBoostOneIter(self.handle, dtrain.handle,
@@ -1726,11 +1733,9 @@ class Booster(object):
         """
         for d in evals:
             if not isinstance(d[0], DMatrix):
-                raise TypeError('expected DMatrix, got {}'.format(
-                    type(d[0]).__name__))
+                raise TypeError(f"expected DMatrix, got {type(d[0]).__name__}")
             if not isinstance(d[1], STRING_TYPES):
-                raise TypeError('expected string, got {}'.format(
-                    type(d[1]).__name__))
+                raise TypeError(f"expected string, got {type(d[1]).__name__}")
             self._validate_features(d[0])
 
         dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals])
@@ -1748,9 +1753,11 @@ class Booster(object):
                                                    output_margin=True), dmat)
             if isinstance(feval_ret, list):
                 for name, val in feval_ret:
+                    # pylint: disable=consider-using-f-string
                     res += '\t%s-%s:%f' % (evname, name, val)
             else:
                 name, val = feval_ret
+                # pylint: disable=consider-using-f-string
                 res += '\t%s-%s:%f' % (evname, name, val)
         return res
 
@@ -2218,7 +2225,7 @@ class Booster(object):
             fout.write('\n]')
         else:
             for i, _ in enumerate(ret):
-                fout.write('booster[{}]:\n'.format(i))
+                fout.write(f"booster[{i}]:\n")
                 fout.write(ret[i])
         if need_close:
             fout.close()
@@ -2353,8 +2360,9 @@ class Booster(object):
                 'Install pandas before calling again.'))
 
         if getattr(self, 'booster', None) is not None and self.booster not in {'gbtree', 'dart'}:
-            raise ValueError('This method is not defined for Booster type {}'
-                             .format(self.booster))
+            raise ValueError(
+                f"This method is not defined for Booster type {self.booster}"
+            )
 
         tree_ids = []
         node_ids = []
@@ -2497,6 +2505,7 @@ class Booster(object):
         """
         xgdump = self.get_dump(fmap=fmap)
         values = []
+        # pylint: disable=consider-using-f-string
         regexp = re.compile(r"\[{0}<([\d.Ee+-]+)\]".format(feature))
         for i, _ in enumerate(xgdump):
             m = re.findall(regexp, xgdump[i])
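Note: this is one of the few `.format` calls the commit keeps, adding a suppression instead. One plausible reason is caution around raw-string regex patterns, where any literal `{`/`}` would need doubling inside an f-string; for this particular pattern the two spellings behave identically:

    import re

    feature = "f12"
    via_format = re.compile(r"\[{0}<([\d.Ee+-]+)\]".format(feature))
    via_fstring = re.compile(rf"\[{feature}<([\d.Ee+-]+)\]")

    dump = "0:[f12<0.5E+1] yes=1,no=2"
    print(via_format.findall(dump))   # ['0.5E+1']
    print(via_fstring.findall(dump))  # ['0.5E+1']
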
@@ -2514,7 +2523,7 @@ class Booster(object):
         fn = self.feature_names
         if fn is None:
             # Let xgboost generate the feature names.
-            fn = ["f{0}".format(i) for i in range(self.num_features())]
+            fn = [f"f{i}" for i in range(self.num_features())]
         try:
             index = fn.index(feature)
             feature_t = ft[index]
@@ -282,7 +282,7 @@ class DaskDMatrix:
 
         if len(data.shape) != 2:
             raise ValueError(
-                "Expecting 2 dimensional input, got: {shape}".format(shape=data.shape)
+                f"Expecting 2 dimensional input, got: {data.shape}"
             )
 
         if not isinstance(data, (dd.DataFrame, da.Array)):
@@ -328,12 +328,9 @@ class DaskDMatrix:
         def inconsistent(
             left: List[Any], left_name: str, right: List[Any], right_name: str
         ) -> str:
-            msg = 'Partitions between {a_name} and {b_name} are not ' \
-                  'consistent: {a_len} != {b_len}. ' \
-                  'Please try to repartition/rechunk your data.'.format(
-                      a_name=left_name, b_name=right_name, a_len=len(left),
-                      b_len=len(right)
-                  )
+            msg = (f"Partitions between {left_name} and {right_name} are not "
+                   f"consistent: {len(left)} != {len(right)}. "
+                   f"Please try to repartition/rechunk your data.")
             return msg
 
         def check_columns(parts: Any) -> None:
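Note: the rewrite above drops the backslash continuations and the keyword `.format` call in favor of adjacent string literals inside parentheses, which Python concatenates at compile time. A sketch with illustrative values:

    left, right = [1, 2, 3], [1, 2]
    left_name, right_name = "X", "y"

    msg = (f"Partitions between {left_name} and {right_name} are not "
           f"consistent: {len(left)} != {len(right)}. "
           "Please try to repartition/rechunk your data.")
    print(msg)
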
@@ -683,7 +680,7 @@ def _create_device_quantile_dmatrix(
 ) -> DeviceQuantileDMatrix:
     worker = distributed.get_worker()
     if parts is None:
-        msg = "worker {address} has an empty DMatrix.".format(address=worker.address)
+        msg = f"worker {worker.address} has an empty DMatrix."
         LOGGER.warning(msg)
         import cupy
 
@@ -747,7 +744,7 @@ def _create_dmatrix(
     worker = distributed.get_worker()
     list_of_parts = parts
     if list_of_parts is None:
-        msg = 'worker {address} has an empty DMatrix. '.format(address=worker.address)
+        msg = f"worker {worker.address} has an empty DMatrix."
         LOGGER.warning(msg)
         d = DMatrix(
             numpy.empty((0, 0)),
@@ -806,7 +803,7 @@ def _dmatrix_from_list_of_parts(
 async def _get_rabit_args(n_workers: int, client: "distributed.Client") -> List[bytes]:
     '''Get rabit context arguments from data distribution in DaskDMatrix.'''
     env = await client.run_on_scheduler(_start_tracker, n_workers)
-    rabit_args = [('%s=%s' % item).encode() for item in env.items()]
+    rabit_args = [f"{k}={v}".encode() for k, v in env.items()]
     return rabit_args
 
 # train and predict methods are supposed to be "functional", which meets the
@@ -69,7 +69,7 @@ def _from_scipy_csr(
     """Initialize data from a CSR matrix."""
     if len(data.indices) != len(data.data):
         raise ValueError(
-            "length mismatch: {} vs {}".format(len(data.indices), len(data.data))
+            f"length mismatch: {len(data.indices)} vs {len(data.data)}"
         )
     handle = ctypes.c_void_p()
     args = {
@@ -106,8 +106,7 @@ def _from_scipy_csc(
     feature_types: Optional[List[str]],
 ):
     if len(data.indices) != len(data.data):
-        raise ValueError('length mismatch: {} vs {}'.format(
-            len(data.indices), len(data.data)))
+        raise ValueError(f"length mismatch: {len(data.indices)} vs {len(data.data)}")
     _warn_unused_missing(data, missing)
     handle = ctypes.c_void_p()
     _check_call(_LIB.XGDMatrixCreateFromCSCEx(
@@ -277,8 +276,7 @@ def _transform_pandas_df(
 
     if meta and len(data.columns) > 1:
         raise ValueError(
-            'DataFrame for {meta} cannot have multiple columns'.format(
-                meta=meta)
+            f"DataFrame for {meta} cannot have multiple columns"
         )
 
     dtype = meta_type if meta_type else np.float32
@@ -184,7 +184,7 @@ def to_graphviz(booster, fmap='', num_trees=0, rankdir=None,
         kwargs['graph_attrs'] = {}
     kwargs['graph_attrs']['rankdir'] = rankdir
     for key, value in extra.items():
-        if 'graph_attrs' in kwargs.keys():
+        if kwargs.get("graph_attrs", None) is not None:
             kwargs['graph_attrs'][key] = value
         else:
             kwargs['graph_attrs'] = {}
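Note: pylint's `consider-iterating-dictionary` flags `key in d.keys()`; the plain `key in d` test is the usual fix. The replacement chosen above, `kwargs.get(...) is not None`, goes a step further and also treats an explicitly stored `None` as missing, which the `else` branch then reinitializes; the two tests are not equivalent in that corner case:

    kwargs = {"graph_attrs": None}

    print("graph_attrs" in kwargs)                      # True
    print(kwargs.get("graph_attrs", None) is not None)  # False
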
@@ -185,7 +185,7 @@ def allreduce(data, op, prepare_fun=None):
     if buf.base is data.base:
         buf = buf.copy()
     if buf.dtype not in DTYPE_ENUM__:
-        raise Exception('data type %s not supported' % str(buf.dtype))
+        raise Exception(f"data type {buf.dtype} not supported")
     if prepare_fun is None:
         _check_call(_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
                                         buf.size, DTYPE_ENUM__[buf.dtype],
@@ -335,7 +335,7 @@ def _wrap_evaluation_matrices(
             )
             evals.append(m)
         nevals = len(evals)
-        eval_names = ["validation_{}".format(i) for i in range(nevals)]
+        eval_names = [f"validation_{i}" for i in range(nevals)]
         evals = list(zip(evals, eval_names))
     else:
         if any(
@@ -526,7 +526,7 @@ class XGBModel(XGBModelBase):
                     stack.append(v)
 
             for k, v in internal.items():
-                if k in params.keys() and params[k] is None:
+                if k in params and params[k] is None:
                     params[k] = parse_parameter(v)
         except ValueError:
             pass
@@ -1016,7 +1016,7 @@ class XGBModel(XGBModelBase):
             importance_type=self.importance_type if self.importance_type else dft()
         )
         if b.feature_names is None:
-            feature_names = ["f{0}".format(i) for i in range(self.n_features_in_)]
+            feature_names = [f"f{i}" for i in range(self.n_features_in_)]
         else:
             feature_names = b.feature_names
         # gblinear returns all features so the `get` in next line is only for gbtree.
@@ -1044,8 +1044,8 @@ class XGBModel(XGBModelBase):
         """
         if self.get_params()['booster'] != 'gblinear':
             raise AttributeError(
-                'Coefficients are not defined for Booster type {}'
-                .format(self.booster))
+                f"Coefficients are not defined for Booster type {self.booster}"
+            )
         b = self.get_booster()
         coef = np.array(json.loads(b.get_dump(dump_format='json')[0])['weight'])
         # Logic for multiclass classification
@@ -1074,8 +1074,8 @@ class XGBModel(XGBModelBase):
         """
         if self.get_params()['booster'] != 'gblinear':
             raise AttributeError(
-                'Intercept (bias) is not defined for Booster type {}'
-                .format(self.booster))
+                f"Intercept (bias) is not defined for Booster type {self.booster}"
+            )
         b = self.get_booster()
         return np.array(json.loads(b.get_dump(dump_format='json')[0])['bias'])
 
@@ -64,7 +64,7 @@ class SlaveEntry(object):
         self.sock = slave
         self.host = get_some_ip(s_addr[0])
         magic = slave.recvint()
-        assert magic == kMagic, 'invalid magic number=%d from %s' % (magic, self.host)
+        assert magic == kMagic, f"invalid magic number={magic} from {self.host}"
         slave.sendint(kMagic)
         self.rank = slave.recvint()
         self.world_size = slave.recvint()
@@ -296,7 +296,7 @@ class RabitTracker(object):
                 shutdown[s.rank] = s
                 logging.debug('Received %s signal from %d', s.cmd, s.rank)
                 continue
-            assert s.cmd == 'start' or s.cmd == 'recover'
+            assert s.cmd in ("start", "recover")
             # lazily initialize the slaves
             if tree_map is None:
                 assert s.cmd == 'start'
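Note: `assert s.cmd == 'start' or s.cmd == 'recover'` trips pylint's `consider-using-in` (R1714), which suggests collapsing repeated equality tests on one value into a membership check:

    cmd = "recover"

    # flagged: two comparisons against the same variable
    ok_old = cmd == "start" or cmd == "recover"
    # preferred: a membership test against a tuple
    ok_new = cmd in ("start", "recover")
    assert ok_old == ok_new
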
@@ -306,7 +306,7 @@ class RabitTracker(object):
                 # set of nodes that is pending for getting up
                 todo_nodes = list(range(nslave))
             else:
-                assert s.world_size == -1 or s.world_size == nslave
+                assert s.world_size in (-1, nslave)
                 if s.cmd == 'recover':
                     assert s.rank >= 0
 
@@ -392,7 +392,7 @@ def start_rabit_tracker(args):
     sys.stdout.write('DMLC_TRACKER_ENV_START\n')
     # simply write configuration to stdout
     for k, v in envs.items():
-        sys.stdout.write('%s=%s\n' % (k, str(v)))
+        sys.stdout.write(f"{k}={v}\n")
     sys.stdout.write('DMLC_TRACKER_ENV_END\n')
     sys.stdout.flush()
     rabit.join()
@@ -419,7 +419,7 @@ def main():
     elif args.log_level == 'DEBUG':
         level = logging.DEBUG
     else:
-        raise RuntimeError("Unknown logging level %s" % args.log_level)
+        raise RuntimeError(f"Unknown logging level {args.log_level}")
 
     logging.basicConfig(format=fmt, level=level)
 