Fix pylint (#7241)

Jiaming Yuan 2021-09-17 11:50:36 +08:00 committed by GitHub
parent 38a23f66a8
commit b18f5f61b0
9 changed files with 80 additions and 71 deletions

python-package/setup.py

@@ -229,9 +229,9 @@ class InstallLib(install_lib.install_lib):
             os.mkdir(lib_dir)
         dst = os.path.join(self.install_dir, 'xgboost', 'lib', lib_name())
-        global BUILD_TEMP_DIR  # pylint: disable=global-statement
         libxgboost_path = lib_name()
+        assert BUILD_TEMP_DIR is not None
         dft_lib_dir = os.path.join(CURRENT_DIR, os.path.pardir, 'lib')
         build_dir = os.path.join(BUILD_TEMP_DIR, 'xgboost', 'lib')
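A note on the rule behind this hunk: pylint's global-statement check fires only when a function rebinds a module-level name; merely reading one needs no global declaration, so the declaration can be dropped and replaced by an assertion that an earlier build step populated the variable. A minimal sketch of the distinction (hypothetical names, not the setup.py code):

    BUILD_TEMP_DIR = None  # module-level; set by an earlier build step

    def read_only():
        # Reading a module-level name needs no `global` declaration.
        assert BUILD_TEMP_DIR is not None
        return BUILD_TEMP_DIR

    def rebind(path):
        # Rebinding does need it, and that is what pylint flags.
        global BUILD_TEMP_DIR  # pylint: disable=global-statement
        BUILD_TEMP_DIR = path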

python-package/xgboost/callback.py

@@ -28,11 +28,11 @@ def _get_callback_context(env):
 def _fmt_metric(value, show_stdv=True):
     """format metric string"""
     if len(value) == 2:
-        return '{0}:{1:.5f}'.format(value[0], value[1])
+        return f"{value[0]}:{value[1]:.5f}"
     if len(value) == 3:
         if show_stdv:
-            return '{0}:{1:.5f}+{2:.5f}'.format(value[0], value[1], value[2])
-        return '{0}:{1:.5f}'.format(value[0], value[1])
+            return f"{value[0]}:{value[1]:.5f}+{value[2]:.5f}"
+        return f"{value[0]}:{value[1]:.5f}"
     raise ValueError("wrong metric value", value)
@@ -62,7 +62,7 @@ def print_evaluation(period=1, show_stdv=True):
         i = env.iteration
         if i % period == 0 or i + 1 == env.begin_iteration or i + 1 == env.end_iteration:
             msg = '\t'.join([_fmt_metric(x, show_stdv) for x in env.evaluation_result_list])
-            rabit.tracker_print('[%d]\t%s\n' % (i, msg))
+            rabit.tracker_print(f"{i}\t{msg}\n")
     return callback
@@ -217,9 +217,11 @@ def early_stop(stopping_rounds, maximize=False, verbose=True):
            state['best_score'] = float('-inf')
        else:
            state['best_score'] = float('inf')
+       # pylint: disable=consider-using-f-string
        msg = '[%d]\t%s' % (
            env.iteration,
-           '\t'.join([_fmt_metric(x) for x in env.evaluation_result_list]))
+           '\t'.join([_fmt_metric(x) for x in env.evaluation_result_list])
+       )
        state['best_msg'] = msg
        if bst is not None:
@@ -243,6 +245,7 @@ def early_stop(stopping_rounds, maximize=False, verbose=True):
        maximize_score = state['maximize_score']
        if (maximize_score and score > best_score) or \
                (not maximize_score and score < best_score):
+           # pylint: disable=consider-using-f-string
            msg = '[%d]\t%s' % (
                env.iteration,
                '\t'.join([_fmt_metric(x) for x in env.evaluation_result_list]))
@@ -363,7 +366,7 @@ class CallbackContainer:
               ' will invoke monitor automatically.'
        assert callable(metric), msg
        self.metric = metric
-       self.history: CallbackContainer.EvalsLog = collections.OrderedDict()
+       self.history: TrainingCallback.EvalsLog = collections.OrderedDict()
        self.is_cv = is_cv
        if self.is_cv:
@@ -519,7 +522,7 @@ class EarlyStopping(TrainingCallback):
        self.rounds = rounds
        self.save_best = save_best
        self.maximize = maximize
-       self.stopping_history: CallbackContainer.EvalsLog = {}
+       self.stopping_history: TrainingCallback.EvalsLog = {}
        self._min_delta = min_delta
        if self._min_delta < 0:
            raise ValueError("min_delta must be greater or equal to 0.")
@@ -589,7 +592,7 @@ class EarlyStopping(TrainingCallback):
        return False

    def after_iteration(self, model, epoch: int,
-                       evals_log: CallbackContainer.EvalsLog) -> bool:
+                       evals_log: TrainingCallback.EvalsLog) -> bool:
        epoch += self.starting_round  # training continuation
        msg = 'Must have at least 1 validation dataset for early stopping.'
        assert len(evals_log.keys()) >= 1, msg
@@ -653,15 +656,17 @@ class EvaluationMonitor(TrainingCallback):
        self._latest: Optional[str] = None
        super().__init__()

-   def _fmt_metric(self, data, metric, score, std) -> str:
+   def _fmt_metric(
+       self, data: str, metric: str, score: float, std: Optional[float]
+   ) -> str:
        if std is not None and self.show_stdv:
-           msg = '\t{0}:{1:.5f}+{2:.5f}'.format(data + '-' + metric, score, std)
+           msg = f"\t{data + '-' + metric}:{score:.5f}+{std:.5f}"
        else:
-           msg = '\t{0}:{1:.5f}'.format(data + '-' + metric, score)
+           msg = f"\t{data + '-' + metric}:{score:.5f}"
        return msg

    def after_iteration(self, model, epoch: int,
-                       evals_log: CallbackContainer.EvalsLog) -> bool:
+                       evals_log: TrainingCallback.EvalsLog) -> bool:
        if not evals_log:
            return False
@@ -723,7 +728,7 @@ class TrainingCheckPoint(TrainingCallback):
        super().__init__()

    def after_iteration(self, model, epoch: int,
-                       evals_log: CallbackContainer.EvalsLog) -> bool:
+                       evals_log: TrainingCallback.EvalsLog) -> bool:
        if self._epoch == self._iterations:
            path = os.path.join(self._path, self._name + '_' + str(epoch) +
                                ('.pkl' if self._as_pickle else '.json'))
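For context on the repeated CallbackContainer.EvalsLog → TrainingCallback.EvalsLog changes above: EvalsLog is the public type alias for the evaluation history handed to callbacks, and it is defined on TrainingCallback, so the annotations now name it at its definition site. A minimal sketch of a user-defined callback against that interface (hedged; written from the xgboost 1.5-era callback API):

    import xgboost as xgb

    class PrintLast(xgb.callback.TrainingCallback):
        """Print the newest value of every metric after each iteration."""

        def after_iteration(
            self, model, epoch: int, evals_log: xgb.callback.TrainingCallback.EvalsLog
        ) -> bool:
            for data, metrics in evals_log.items():
                for name, history in metrics.items():
                    # For plain (non-CV) training each history entry is a float.
                    print(f"[{epoch}] {data}-{name}: {history[-1]:.5f}")
            return False  # returning True would request early termination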

python-package/xgboost/core.py

@@ -140,9 +140,9 @@ def _expect(expectations, got):
     return msg


-def _log_callback(msg):
+def _log_callback(msg: bytes) -> None:
     """Redirect logs from native library into Python console"""
-    print("{0:s}".format(py_str(msg)))
+    print(py_str(msg))


 def _get_log_callback_func():
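The new msg: bytes annotation matches how ctypes delivers c_char_p arguments to Python callbacks: the native side hands over raw bytes, which py_str decodes before printing. A self-contained sketch of the same pattern (hypothetical callback wiring, not xgboost's actual registration code):

    import ctypes

    # A native log callback receives a C string, i.e. Python bytes.
    LOG_CALLBACK = ctypes.CFUNCTYPE(None, ctypes.c_char_p)

    def _log_callback(msg: bytes) -> None:
        print(msg.decode("utf-8"))

    cb = LOG_CALLBACK(_log_callback)
    cb(b"hello from native code")  # prints: hello from native code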
@@ -179,14 +179,19 @@ def _load_lib():
     if not lib_success:
         libname = os.path.basename(lib_paths[0])
         raise XGBoostError(
-            'XGBoost Library ({}) could not be loaded.\n'.format(libname) +
-            'Likely causes:\n' +
-            '  * OpenMP runtime is not installed ' +
-            '(vcomp140.dll or libgomp-1.dll for Windows, libomp.dylib for Mac OSX, ' +
-            'libgomp.so for Linux and other UNIX-like OSes). Mac OSX users: Run ' +
-            '`brew install libomp` to install OpenMP runtime.\n' +
-            '  * You are running 32-bit Python on a 64-bit OS\n' +
-            'Error message(s): {}\n'.format(os_error_list))
+            f"""
+XGBoost Library ({libname}) could not be loaded.
+Likely causes:
+  * OpenMP runtime is not installed
+    - vcomp140.dll or libgomp-1.dll for Windows
+    - libomp.dylib for Mac OSX
+    - libgomp.so for Linux and other UNIX-like OSes
+    Mac OSX users: Run `brew install libomp` to install OpenMP runtime.
+  * You are running 32-bit Python on a 64-bit OS
+Error message(s): {os_error_list}
+""")
     lib.XGBGetLastError.restype = ctypes.c_char_p
     lib.callback = _get_log_callback_func()
     if lib.XGBRegisterLogCallback(lib.callback) != 0:
@@ -246,7 +251,7 @@ def ctypes2numpy(cptr, length, dtype):
     """Convert a ctypes pointer array to a numpy array."""
     ctype = _numpy2ctypes_type(dtype)
     if not isinstance(cptr, ctypes.POINTER(ctype)):
-        raise RuntimeError("expected {} pointer".format(ctype))
+        raise RuntimeError(f"expected {ctype} pointer")
     res = np.zeros(length, dtype=dtype)
     if not ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0]):
         raise RuntimeError("memmove failed")
@@ -262,7 +267,7 @@ def ctypes2cupy(cptr, length, dtype):
     CUPY_TO_CTYPES_MAPPING = {cupy.float32: ctypes.c_float, cupy.uint32: ctypes.c_uint}
     if dtype not in CUPY_TO_CTYPES_MAPPING.keys():
-        raise RuntimeError("Supported types: {}".format(CUPY_TO_CTYPES_MAPPING.keys()))
+        raise RuntimeError(f"Supported types: {CUPY_TO_CTYPES_MAPPING.keys()}")
     addr = ctypes.cast(cptr, ctypes.c_void_p).value
     # pylint: disable=c-extension-no-member,no-member
     device = cupy.cuda.runtime.pointerGetAttributes(addr).device
@@ -486,13 +491,16 @@ def _deprecate_positional_args(f):
        if extra_args > 0:
            # ignore first 'self' argument for instance methods
            args_msg = [
-               '{}'.format(name) for name, _ in zip(
-                   kwonly_args[:extra_args], args[-extra_args:])
+               f"{name}" for name, _ in zip(
+                   kwonly_args[:extra_args], args[-extra_args:]
+               )
            ]
+           # pylint: disable=consider-using-f-string
            warnings.warn(
                "Pass `{}` as keyword args. Passing these as positional "
                "arguments will be considered as error in future releases.".
-               format(", ".join(args_msg)), FutureWarning)
+               format(", ".join(args_msg)), FutureWarning
+           )
            for k, arg in zip(sig.parameters, args):
                kwargs[k] = arg
        return f(**kwargs)
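For readers unfamiliar with the pattern this decorator enforces: parameters declared after a bare * are keyword-only, and _deprecate_positional_args exists to warn, rather than hard-fail, while callers migrate. A toy sketch of the end state, with hypothetical names:

    def train(data, *, num_boost_round=10, verbose=False):
        # Everything after the bare `*` must be passed by keyword.
        return data, num_boost_round, verbose

    train([1, 2, 3], num_boost_round=5)  # fine
    # train([1, 2, 3], 5)  # TypeError: too many positional arguments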
@@ -1292,7 +1300,7 @@ class Booster(object):
        """
        for d in cache:
            if not isinstance(d, DMatrix):
-               raise TypeError('invalid cache item: {}'.format(type(d).__name__), cache)
+               raise TypeError(f'invalid cache item: {type(d).__name__}', cache)
        dmats = c_array(ctypes.c_void_p, [d.handle for d in cache])
        self.handle = ctypes.c_void_p()
@@ -1665,8 +1673,7 @@ class Booster(object):
        """
        if not isinstance(dtrain, DMatrix):
-           raise TypeError('invalid training matrix: {}'.format(
-               type(dtrain).__name__))
+           raise TypeError(f"invalid training matrix: {type(dtrain).__name__}")
        self._validate_features(dtrain)

        if fobj is None:
@@ -1695,10 +1702,10 @@ class Booster(object):
        """
        if len(grad) != len(hess):
            raise ValueError(
-               'grad / hess length mismatch: {} / {}'.format(len(grad), len(hess))
+               f"grad / hess length mismatch: {len(grad)} / {len(hess)}"
            )
        if not isinstance(dtrain, DMatrix):
-           raise TypeError('invalid training matrix: {}'.format(type(dtrain).__name__))
+           raise TypeError(f"invalid training matrix: {type(dtrain).__name__}")
        self._validate_features(dtrain)

        _check_call(_LIB.XGBoosterBoostOneIter(self.handle, dtrain.handle,
@@ -1726,11 +1733,9 @@ class Booster(object):
        """
        for d in evals:
            if not isinstance(d[0], DMatrix):
-               raise TypeError('expected DMatrix, got {}'.format(
-                   type(d[0]).__name__))
+               raise TypeError(f"expected DMatrix, got {type(d[0]).__name__}")
            if not isinstance(d[1], STRING_TYPES):
-               raise TypeError('expected string, got {}'.format(
-                   type(d[1]).__name__))
+               raise TypeError(f"expected string, got {type(d[1]).__name__}")
            self._validate_features(d[0])

        dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals])
@@ -1748,9 +1753,11 @@ class Booster(object):
                                               output_margin=True), dmat)
            if isinstance(feval_ret, list):
                for name, val in feval_ret:
+                   # pylint: disable=consider-using-f-string
                    res += '\t%s-%s:%f' % (evname, name, val)
            else:
                name, val = feval_ret
+               # pylint: disable=consider-using-f-string
                res += '\t%s-%s:%f' % (evname, name, val)
        return res
@@ -2218,7 +2225,7 @@ class Booster(object):
            fout.write('\n]')
        else:
            for i, _ in enumerate(ret):
-               fout.write('booster[{}]:\n'.format(i))
+               fout.write(f"booster[{i}]:\n")
                fout.write(ret[i])
        if need_close:
            fout.close()
@@ -2353,8 +2360,9 @@ class Booster(object):
                                   'Install pandas before calling again.'))
        if getattr(self, 'booster', None) is not None and self.booster not in {'gbtree', 'dart'}:
-           raise ValueError('This method is not defined for Booster type {}'
-                            .format(self.booster))
+           raise ValueError(
+               f"This method is not defined for Booster type {self.booster}"
+           )

        tree_ids = []
        node_ids = []
@@ -2497,6 +2505,7 @@ class Booster(object):
        """
        xgdump = self.get_dump(fmap=fmap)
        values = []
+       # pylint: disable=consider-using-f-string
        regexp = re.compile(r"\[{0}<([\d.Ee+-]+)\]".format(feature))
        for i, _ in enumerate(xgdump):
            m = re.findall(regexp, xgdump[i])
@@ -2514,7 +2523,7 @@ class Booster(object):
        fn = self.feature_names
        if fn is None:
            # Let xgboost generate the feature names.
-           fn = ["f{0}".format(i) for i in range(self.num_features())]
+           fn = [f"f{i}" for i in range(self.num_features())]
        try:
            index = fn.index(feature)
            feature_t = ft[index]

python-package/xgboost/dask.py

@@ -282,7 +282,7 @@ class DaskDMatrix:
        if len(data.shape) != 2:
            raise ValueError(
-               "Expecting 2 dimensional input, got: {shape}".format(shape=data.shape)
+               f"Expecting 2 dimensional input, got: {data.shape}"
            )

        if not isinstance(data, (dd.DataFrame, da.Array)):
@@ -328,12 +328,9 @@ class DaskDMatrix:
    def inconsistent(
        left: List[Any], left_name: str, right: List[Any], right_name: str
    ) -> str:
-       msg = 'Partitions between {a_name} and {b_name} are not ' \
-             'consistent: {a_len} != {b_len}. ' \
-             'Please try to repartition/rechunk your data.'.format(
-                 a_name=left_name, b_name=right_name, a_len=len(left),
-                 b_len=len(right)
-             )
+       msg = (f"Partitions between {left_name} and {right_name} are not "
+              f"consistent: {len(left)} != {len(right)}. "
+              f"Please try to repartition/rechunk your data.")
        return msg

    def check_columns(parts: Any) -> None:
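The rewrite of inconsistent() leans on a detail worth spelling out: adjacent string literals inside parentheses are concatenated at compile time, so a long f-string message needs neither + nor backslash continuations. A sketch:

    left, right = [1, 2], [1, 2, 3]
    left_name, right_name = "X", "y"

    msg = (f"Partitions between {left_name} and {right_name} are not "
           f"consistent: {len(left)} != {len(right)}. "
           "Please try to repartition/rechunk your data.")
    assert msg.startswith("Partitions between X and y are not consistent: 2 != 3.")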
@@ -683,7 +680,7 @@ def _create_device_quantile_dmatrix(
 ) -> DeviceQuantileDMatrix:
     worker = distributed.get_worker()
     if parts is None:
-        msg = "worker {address} has an empty DMatrix.".format(address=worker.address)
+        msg = f"worker {worker.address} has an empty DMatrix."
         LOGGER.warning(msg)
         import cupy
@@ -747,7 +744,7 @@ def _create_dmatrix(
     worker = distributed.get_worker()
     list_of_parts = parts
     if list_of_parts is None:
-        msg = 'worker {address} has an empty DMatrix. '.format(address=worker.address)
+        msg = f"worker {worker.address} has an empty DMatrix."
         LOGGER.warning(msg)
         d = DMatrix(
             numpy.empty((0, 0)),
@@ -806,7 +803,7 @@ def _dmatrix_from_list_of_parts(
 async def _get_rabit_args(n_workers: int, client: "distributed.Client") -> List[bytes]:
     '''Get rabit context arguments from data distribution in DaskDMatrix.'''
     env = await client.run_on_scheduler(_start_tracker, n_workers)
-    rabit_args = [('%s=%s' % item).encode() for item in env.items()]
+    rabit_args = [f"{k}={v}".encode() for k, v in env.items()]
     return rabit_args

 # train and predict methods are supposed to be "functional", which meets the

python-package/xgboost/data.py

@@ -69,7 +69,7 @@ def _from_scipy_csr(
     """Initialize data from a CSR matrix."""
     if len(data.indices) != len(data.data):
         raise ValueError(
-            "length mismatch: {} vs {}".format(len(data.indices), len(data.data))
+            f"length mismatch: {len(data.indices)} vs {len(data.data)}"
         )
     handle = ctypes.c_void_p()
     args = {
@@ -106,8 +106,7 @@ def _from_scipy_csc(
     feature_types: Optional[List[str]],
 ):
     if len(data.indices) != len(data.data):
-        raise ValueError('length mismatch: {} vs {}'.format(
-            len(data.indices), len(data.data)))
+        raise ValueError(f"length mismatch: {len(data.indices)} vs {len(data.data)}")
     _warn_unused_missing(data, missing)
     handle = ctypes.c_void_p()
     _check_call(_LIB.XGDMatrixCreateFromCSCEx(
@@ -277,8 +276,7 @@ def _transform_pandas_df(
     if meta and len(data.columns) > 1:
         raise ValueError(
-            'DataFrame for {meta} cannot have multiple columns'.format(
-                meta=meta)
+            f"DataFrame for {meta} cannot have multiple columns"
         )
     dtype = meta_type if meta_type else np.float32

python-package/xgboost/plotting.py

@@ -184,7 +184,7 @@ def to_graphviz(booster, fmap='', num_trees=0, rankdir=None,
            kwargs['graph_attrs'] = {}
        kwargs['graph_attrs']['rankdir'] = rankdir
    for key, value in extra.items():
-       if 'graph_attrs' in kwargs.keys():
+       if kwargs.get("graph_attrs", None) is not None:
            kwargs['graph_attrs'][key] = value
        else:
            kwargs['graph_attrs'] = {}
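This plotting hunk is arguably not purely cosmetic, so the distinction deserves a sketch: 'graph_attrs' in kwargs is true even when the key maps to None, whereas kwargs.get("graph_attrs", None) is not None treats an explicit None like a missing key, falling through to the branch that installs a fresh dict instead of indexing into None:

    kwargs = {"graph_attrs": None}

    print("graph_attrs" in kwargs)                      # True
    print(kwargs.get("graph_attrs", None) is not None)  # False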

python-package/xgboost/rabit.py

@@ -185,7 +185,7 @@ def allreduce(data, op, prepare_fun=None):
    if buf.base is data.base:
        buf = buf.copy()
    if buf.dtype not in DTYPE_ENUM__:
-       raise Exception('data type %s not supported' % str(buf.dtype))
+       raise Exception(f"data type {buf.dtype} not supported")
    if prepare_fun is None:
        _check_call(_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
                                        buf.size, DTYPE_ENUM__[buf.dtype],

python-package/xgboost/sklearn.py

@@ -335,7 +335,7 @@ def _wrap_evaluation_matrices(
            )
            evals.append(m)
        nevals = len(evals)
-       eval_names = ["validation_{}".format(i) for i in range(nevals)]
+       eval_names = [f"validation_{i}" for i in range(nevals)]
        evals = list(zip(evals, eval_names))
    else:
        if any(
@@ -526,7 +526,7 @@ class XGBModel(XGBModelBase):
                    stack.append(v)
                for k, v in internal.items():
-                   if k in params.keys() and params[k] is None:
+                   if k in params and params[k] is None:
                        params[k] = parse_parameter(v)
        except ValueError:
            pass
@@ -1016,7 +1016,7 @@ class XGBModel(XGBModelBase):
            importance_type=self.importance_type if self.importance_type else dft()
        )
        if b.feature_names is None:
-           feature_names = ["f{0}".format(i) for i in range(self.n_features_in_)]
+           feature_names = [f"f{i}" for i in range(self.n_features_in_)]
        else:
            feature_names = b.feature_names
        # gblinear returns all features so the `get` in next line is only for gbtree.
@@ -1044,8 +1044,8 @@ class XGBModel(XGBModelBase):
        """
        if self.get_params()['booster'] != 'gblinear':
            raise AttributeError(
-               'Coefficients are not defined for Booster type {}'
-               .format(self.booster))
+               f"Coefficients are not defined for Booster type {self.booster}"
+           )
        b = self.get_booster()
        coef = np.array(json.loads(b.get_dump(dump_format='json')[0])['weight'])
        # Logic for multiclass classification
@@ -1074,8 +1074,8 @@ class XGBModel(XGBModelBase):
        """
        if self.get_params()['booster'] != 'gblinear':
            raise AttributeError(
-               'Intercept (bias) is not defined for Booster type {}'
-               .format(self.booster))
+               f"Intercept (bias) is not defined for Booster type {self.booster}"
+           )
        b = self.get_booster()
        return np.array(json.loads(b.get_dump(dump_format='json')[0])['bias'])

python-package/xgboost/tracker.py

@@ -64,7 +64,7 @@ class SlaveEntry(object):
        self.sock = slave
        self.host = get_some_ip(s_addr[0])
        magic = slave.recvint()
-       assert magic == kMagic, 'invalid magic number=%d from %s' % (magic, self.host)
+       assert magic == kMagic, f"invalid magic number={magic} from {self.host}"
        slave.sendint(kMagic)
        self.rank = slave.recvint()
        self.world_size = slave.recvint()
@@ -296,7 +296,7 @@ class RabitTracker(object):
                shutdown[s.rank] = s
                logging.debug('Received %s signal from %d', s.cmd, s.rank)
                continue
-           assert s.cmd == 'start' or s.cmd == 'recover'
+           assert s.cmd in ("start", "recover")
            # lazily initialize the slaves
            if tree_map is None:
                assert s.cmd == 'start'
@@ -306,7 +306,7 @@ class RabitTracker(object):
                # set of nodes that is pending for getting up
                todo_nodes = list(range(nslave))
            else:
-               assert s.world_size == -1 or s.world_size == nslave
+               assert s.world_size in (-1, nslave)
            if s.cmd == 'recover':
                assert s.rank >= 0
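The two assert rewrites above are pylint's consider-using-in refactor (R1714): a chain of == comparisons against one variable collapses into a membership test, which also compares with == and is therefore equivalent here:

    cmd = "recover"

    # The two checks below are equivalent; pylint prefers the tuple form.
    assert cmd == "start" or cmd == "recover"
    assert cmd in ("start", "recover")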
@@ -392,7 +392,7 @@ def start_rabit_tracker(args):
    sys.stdout.write('DMLC_TRACKER_ENV_START\n')
    # simply write configuration to stdout
    for k, v in envs.items():
-       sys.stdout.write('%s=%s\n' % (k, str(v)))
+       sys.stdout.write(f"{k}={v}\n")
    sys.stdout.write('DMLC_TRACKER_ENV_END\n')
    sys.stdout.flush()
    rabit.join()
@@ -419,7 +419,7 @@ def main():
    elif args.log_level == 'DEBUG':
        level = logging.DEBUG
    else:
-       raise RuntimeError("Unknown logging level %s" % args.log_level)
+       raise RuntimeError(f"Unknown logging level {args.log_level}")
    logging.basicConfig(format=fmt, level=level)
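An aside on this last hunk: the surrounding elif chain only maps a level name to a logging constant, and the standard library can do that lookup directly. A common alternative sketch (an assumption about style, not what this commit does):

    import logging

    def parse_level(name: str) -> int:
        level = getattr(logging, name.upper(), None)
        if not isinstance(level, int):
            raise RuntimeError(f"Unknown logging level {name}")
        return level

    assert parse_level("DEBUG") == logging.DEBUG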