From c6769488b3a615092b60ba351a60c07b95ec85df Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 28 Oct 2021 20:47:04 +0800 Subject: [PATCH] Typehint for subset of core API. (#7348) --- python-package/xgboost/callback.py | 197 +++++++++------- python-package/xgboost/core.py | 226 ++++++++++--------- python-package/xgboost/dask.py | 105 ++++----- python-package/xgboost/sklearn.py | 171 +++++--------- python-package/xgboost/tracker.py | 2 +- python-package/xgboost/training.py | 107 +++++---- tests/python/test_interaction_constraints.py | 4 +- 7 files changed, 400 insertions(+), 412 deletions(-) diff --git a/python-package/xgboost/callback.py b/python-package/xgboost/callback.py index 7552db79d..53a1c4e9a 100644 --- a/python-package/xgboost/callback.py +++ b/python-package/xgboost/callback.py @@ -6,17 +6,18 @@ from abc import ABC import collections import os import pickle -from typing import Callable, List, Optional, Union, Dict, Tuple +from typing import Callable, List, Optional, Union, Dict, Tuple, TypeVar, cast +from typing import Sequence import numpy from . import rabit -from .core import Booster, XGBoostError +from .core import Booster, DMatrix, XGBoostError from .compat import STRING_TYPES -# The new implementation of callback functions. -# Breaking: -# - reset learning rate no longer accepts total boosting rounds +_Score = Union[float, Tuple[float, float]] +_ScoreList = Union[List[float], List[Tuple[float, float]]] + # pylint: disable=unused-argument class TrainingCallback(ABC): @@ -26,9 +27,9 @@ class TrainingCallback(ABC): ''' - EvalsLog = Dict[str, Dict[str, Union[List[float], List[Tuple[float, float]]]]] + EvalsLog = Dict[str, Dict[str, _ScoreList]] - def __init__(self): + def __init__(self) -> None: pass def before_training(self, model): @@ -48,18 +49,18 @@ class TrainingCallback(ABC): return False -def _aggcv(rlist): - # pylint: disable=invalid-name +def _aggcv(rlist: List[str]) -> List[Tuple[str, float, float]]: + # pylint: disable=invalid-name, too-many-locals """Aggregate cross-validation results. """ - cvmap = {} + cvmap: Dict[Tuple[int, str], List[float]] = {} idx = rlist[0].split()[0] for line in rlist: - arr = line.split() + arr: List[str] = line.split() assert idx == arr[0] for metric_idx, it in enumerate(arr[1:]): - if not isinstance(it, STRING_TYPES): + if not isinstance(it, str): it = it.decode() k, v = it.split(':') if (metric_idx, k) not in cvmap: @@ -67,16 +68,20 @@ def _aggcv(rlist): cvmap[(metric_idx, k)].append(float(v)) msg = idx results = [] - for (metric_idx, k), v in sorted(cvmap.items(), key=lambda x: x[0][0]): - v = numpy.array(v) + for (_, name), s in sorted(cvmap.items(), key=lambda x: x[0][0]): + as_arr = numpy.array(s) if not isinstance(msg, STRING_TYPES): msg = msg.decode() - mean, std = numpy.mean(v), numpy.std(v) - results.extend([(k, mean, std)]) + mean, std = numpy.mean(as_arr), numpy.std(as_arr) + results.extend([(name, mean, std)]) return results -def _allreduce_metric(score): +# allreduce type +_ART = TypeVar("_ART") + + +def _allreduce_metric(score: _ART) -> _ART: '''Helper function for computing customized metric in distributed environment. Not strictly correct as many functions don't use mean value as final result. 
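Typing note: the `_ART` TypeVar above lets `_allreduce_metric` return the same score type it receives, matching the `_Score = Union[float, Tuple[float, float]]` alias introduced at the top of the module. A minimal, self-contained sketch of the same pattern follows; the names `Score`, `T`, and `identity_metric` are illustrative only and not part of the xgboost API.

from typing import Tuple, TypeVar, Union

Score = Union[float, Tuple[float, float]]  # single metric value, or (mean, std) from cv
T = TypeVar("T")

def identity_metric(score: T) -> T:
    # Generic pass-through: a type checker infers the return type from the
    # argument, so float stays float and Tuple[float, float] stays a tuple,
    # without casts at the call site.
    return score

plain: float = identity_metric(0.25)
with_std: Tuple[float, float] = identity_metric((0.25, 0.01))
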
@@ -89,13 +94,13 @@ def _allreduce_metric(score): if isinstance(score, tuple): # has mean and stdv raise ValueError( 'xgboost.cv function should not be used in distributed environment.') - score = numpy.array([score]) - score = rabit.allreduce(score, rabit.Op.SUM) / world - return score[0] + arr = numpy.array([score]) + arr = rabit.allreduce(arr, rabit.Op.SUM) / world + return arr[0] class CallbackContainer: - '''A special callback for invoking a list of other callbacks. + '''A special internal callback for invoking a list of other callbacks. .. versionadded:: 1.3.0 @@ -105,7 +110,7 @@ class CallbackContainer: def __init__( self, - callbacks: List[TrainingCallback], + callbacks: Sequence[TrainingCallback], metric: Callable = None, output_margin: bool = True, is_cv: bool = False @@ -146,33 +151,50 @@ class CallbackContainer: assert isinstance(model, Booster), msg return model - def before_iteration(self, model, epoch, dtrain, evals) -> bool: + def before_iteration( + self, model, epoch: int, dtrain: DMatrix, evals: List[Tuple[DMatrix, str]] + ) -> bool: '''Function called before training iteration.''' return any(c.before_iteration(model, epoch, self.history) for c in self.callbacks) - def _update_history(self, score, epoch): + def _update_history( + self, + score: Union[List[Tuple[str, float]], List[Tuple[str, float, float]]], + epoch: int + ) -> None: for d in score: - name, s = d[0], float(d[1]) + name: str = d[0] + s: float = d[1] if self.is_cv: - std = float(d[2]) - s = (s, std) + std = float(cast(Tuple[str, float, float], d)[2]) + x: _Score = (s, std) + else: + x = s splited_names = name.split('-') data_name = splited_names[0] metric_name = '-'.join(splited_names[1:]) - s = _allreduce_metric(s) - if data_name in self.history: - data_history = self.history[data_name] - if metric_name in data_history: - data_history[metric_name].append(s) - else: - data_history[metric_name] = [s] - else: + x = _allreduce_metric(x) + if data_name not in self.history: self.history[data_name] = collections.OrderedDict() - self.history[data_name][metric_name] = [s] - return False + data_history = self.history[data_name] + if metric_name not in data_history: + data_history[metric_name] = cast(_ScoreList, []) + metric_history = data_history[metric_name] + if self.is_cv: + cast(List[Tuple[float, float]], metric_history).append( + cast(Tuple[float, float], x) + ) + else: + cast(List[float], metric_history).append(cast(float, x)) - def after_iteration(self, model, epoch, dtrain, evals) -> bool: + def after_iteration( + self, + model, + epoch: int, + dtrain: DMatrix, + evals: Optional[List[Tuple[DMatrix, str]]], + ) -> bool: '''Function called after training iteration.''' if self.is_cv: scores = model.eval(epoch, self.metric, self._output_margin) @@ -183,18 +205,20 @@ class CallbackContainer: evals = [] if evals is None else evals for _, name in evals: assert name.find('-') == -1, 'Dataset name should not contain `-`' - score = model.eval_set(evals, epoch, self.metric, self._output_margin) - score = score.split()[1:] # into datasets + score: str = model.eval_set(evals, epoch, self.metric, self._output_margin) + splited = score.split()[1:] # into datasets # split up `test-error:0.1234` - score = [tuple(s.split(':')) for s in score] - self._update_history(score, epoch) + metric_score_str = [tuple(s.split(':')) for s in splited] + # convert to float + metric_score = [(n, float(s)) for n, s in metric_score_str] + self._update_history(metric_score, epoch) ret = any(c.after_iteration(model, epoch, self.history) for c in 
self.callbacks) return ret class LearningRateScheduler(TrainingCallback): - '''Callback function for scheduling learning rate. + """Callback function for scheduling learning rate. .. versionadded:: 1.3.0 @@ -207,18 +231,24 @@ class LearningRateScheduler(TrainingCallback): should be a sequence like list or tuple with the same size of boosting rounds. - ''' - def __init__(self, learning_rates) -> None: - assert callable(learning_rates) or \ - isinstance(learning_rates, collections.abc.Sequence) + """ + + def __init__( + self, learning_rates: Union[Callable[[int], float], Sequence[float]] + ) -> None: + assert callable(learning_rates) or isinstance( + learning_rates, collections.abc.Sequence + ) if callable(learning_rates): self.learning_rates = learning_rates else: - self.learning_rates = lambda epoch: learning_rates[epoch] + self.learning_rates = lambda epoch: cast(Sequence, learning_rates)[epoch] super().__init__() - def after_iteration(self, model, epoch, evals_log) -> bool: - model.set_param('learning_rate', self.learning_rates(epoch)) + def after_iteration( + self, model, epoch: int, evals_log: TrainingCallback.EvalsLog + ) -> bool: + model.set_param("learning_rate", self.learning_rates(epoch)) return False @@ -230,17 +260,17 @@ class EarlyStopping(TrainingCallback): Parameters ---------- - rounds + rounds : Early stopping rounds. - metric_name + metric_name : Name of metric that is used for early stopping. - data_name + data_name : Name of dataset that is used for early stopping. - maximize + maximize : Whether to maximize evaluation metric. None means auto (discouraged). - save_best + save_best : Whether training should return the best model or the last model. - min_delta + min_delta : Minimum absolute change in score to be qualified as an improvement. .. 
versionadded:: 1.5.0 @@ -279,8 +309,6 @@ class EarlyStopping(TrainingCallback): if self._min_delta < 0: raise ValueError("min_delta must be greater or equal to 0.") - self.improve_op = None - self.current_rounds: int = 0 self.best_scores: dict = {} self.starting_round: int = 0 @@ -290,16 +318,18 @@ class EarlyStopping(TrainingCallback): self.starting_round = model.num_boosted_rounds() return model - def _update_rounds(self, score, name, metric, model, epoch) -> bool: - def get_s(x): + def _update_rounds( + self, score: _Score, name: str, metric: str, model, epoch: int + ) -> bool: + def get_s(x: _Score) -> float: """get score if it's cross validation history.""" return x[0] if isinstance(x, tuple) else x - def maximize(new, best): + def maximize(new: _Score, best: _Score) -> bool: """New score should be greater than the old one.""" return numpy.greater(get_s(new) - self._min_delta, get_s(best)) - def minimize(new, best): + def minimize(new: _Score, best: _Score) -> bool: """New score should be smaller than the old one.""" return numpy.greater(get_s(best) - self._min_delta, get_s(new)) @@ -314,25 +344,25 @@ class EarlyStopping(TrainingCallback): self.maximize = False if self.maximize: - self.improve_op = maximize + improve_op = maximize else: - self.improve_op = minimize + improve_op = minimize - assert self.improve_op + assert improve_op if not self.stopping_history: # First round self.current_rounds = 0 self.stopping_history[name] = {} - self.stopping_history[name][metric] = [score] + self.stopping_history[name][metric] = cast(_ScoreList, [score]) self.best_scores[name] = {} self.best_scores[name][metric] = [score] model.set_attr(best_score=str(score), best_iteration=str(epoch)) - elif not self.improve_op(score, self.best_scores[name][metric][-1]): + elif not improve_op(score, self.best_scores[name][metric][-1]): # Not improved - self.stopping_history[name][metric].append(score) + self.stopping_history[name][metric].append(score) # type: ignore self.current_rounds += 1 else: # Improved - self.stopping_history[name][metric].append(score) + self.stopping_history[name][metric].append(score) # type: ignore self.best_scores[name][metric].append(score) record = self.stopping_history[name][metric][-1] model.set_attr(best_score=str(record), best_iteration=str(epoch)) @@ -390,16 +420,16 @@ class EvaluationMonitor(TrainingCallback): Parameters ---------- - metric : callable + metric : Extra user defined metric. - rank : int + rank : Which worker should be used for printing the result. - period : int + period : How many epoches between printing. - show_stdv : bool + show_stdv : Used in cv to show standard deviation. Users should not specify it. ''' - def __init__(self, rank=0, period=1, show_stdv=False) -> None: + def __init__(self, rank: int = 0, period: int = 1, show_stdv: bool = False) -> None: self.printer_rank = rank self.show_stdv = show_stdv self.period = period @@ -457,22 +487,27 @@ class TrainingCheckPoint(TrainingCallback): Parameters ---------- - directory : os.PathLike + directory : Output model directory. - name : str + name : pattern of output model file. Models will be saved as name_0.json, name_1.json, name_2.json .... - as_pickle : boolean + as_pickle : When set to Ture, all training parameters will be saved in pickle format, instead of saving only the model. - iterations : int + iterations : Interval of checkpointing. Checkpointing is slow so setting a larger number can reduce performance hit. 
''' - def __init__(self, directory: os.PathLike, name: str = 'model', - as_pickle=False, iterations: int = 100): - self._path = directory + def __init__( + self, + directory: Union[str, os.PathLike], + name: str = 'model', + as_pickle: bool = False, + iterations: int = 100 + ) -> None: + self._path = os.fspath(directory) self._name = name self._as_pickle = as_pickle self._iterations = iterations diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index e53fdb21d..594ed74bb 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -6,7 +6,7 @@ from collections.abc import Mapping from typing import List, Optional, Any, Union, Dict, TypeVar # pylint: enable=no-name-in-module,import-error -from typing import Callable, Tuple, cast +from typing import Callable, Tuple, cast, Sequence import ctypes import os import re @@ -31,20 +31,6 @@ class XGBoostError(ValueError): """Error thrown by xgboost trainer.""" -class EarlyStopException(Exception): - """Exception to signal early stopping. - - Parameters - ---------- - best_iteration : int - The best iteration stopped. - """ - - def __init__(self, best_iteration): - super().__init__() - self.best_iteration = best_iteration - - def from_pystr_to_cstr(data: Union[str, List[str]]): """Convert a Python str or list of Python str to C pointer @@ -132,18 +118,19 @@ def _log_callback(msg: bytes) -> None: print(py_str(msg)) -def _get_log_callback_func(): +def _get_log_callback_func() -> Callable: """Wrap log_callback() method in ctypes callback type""" # pylint: disable=invalid-name CALLBACK = ctypes.CFUNCTYPE(None, ctypes.c_char_p) return CALLBACK(_log_callback) -def _load_lib(): +def _load_lib() -> ctypes.CDLL: """Load xgboost Library.""" lib_paths = find_lib_path() if not lib_paths: - return None + # This happens only when building document. + return None # type: ignore try: pathBackup = os.environ['PATH'].split(os.pathsep) except KeyError: @@ -190,7 +177,7 @@ Error message(s): {os_error_list} _LIB = _load_lib() -def _check_call(ret): +def _check_call(ret: int) -> None: """Check the return value of C API call This function will raise exception when error occurs. @@ -234,7 +221,7 @@ def _cuda_array_interface(data) -> bytes: return interface_str -def ctypes2numpy(cptr, length, dtype): +def ctypes2numpy(cptr, length, dtype) -> np.ndarray: """Convert a ctypes pointer array to a numpy array.""" ctype = _numpy2ctypes_type(dtype) if not isinstance(cptr, ctypes.POINTER(ctype)): @@ -271,7 +258,7 @@ def ctypes2cupy(cptr, length, dtype): return arr -def ctypes2buffer(cptr, length): +def ctypes2buffer(cptr, length) -> bytearray: """Convert ctypes pointer to buffer type.""" if not isinstance(cptr, ctypes.POINTER(ctypes.c_char)): raise RuntimeError('expected char pointer') @@ -428,7 +415,7 @@ class DataIter: # pylint: disable=too-many-instance-attributes Parameters ---------- - data_handle: + input_data: A function with same data fields like `data`, `label` with `xgboost.DMatrix`. 
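Several signatures in this patch widen file and directory arguments to Union[str, os.PathLike] and normalize them with os.fspath, as TrainingCheckPoint.__init__ does above. A small sketch of that pattern, using only the standard library; the helper name to_str_path is hypothetical.

import os
import pathlib
from typing import Union

def to_str_path(path: Union[str, os.PathLike]) -> str:
    # os.fspath accepts both str and os.PathLike and returns the underlying
    # string, so downstream code can build file names without caring which
    # type the caller passed in.
    return os.fspath(path)

assert to_str_path("models") == to_str_path(pathlib.Path("models"))
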
@@ -627,7 +614,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes if feature_types is not None: self.feature_types = feature_types - def _init_from_iter(self, iterator: DataIter, enable_categorical: bool): + def _init_from_iter(self, iterator: DataIter, enable_categorical: bool) -> None: it = iterator args = { "missing": self.missing, @@ -654,7 +641,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes _check_call(ret) self.handle = handle - def __del__(self): + def __del__(self) -> None: if hasattr(self, "handle") and self.handle: _check_call(_LIB.XGDMatrixFree(self.handle)) self.handle = None @@ -699,7 +686,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes dispatch_meta_backend(matrix=self, data=feature_weights, name='feature_weights') - def get_float_info(self, field): + def get_float_info(self, field: str) -> np.ndarray: """Get float property from the DMatrix. Parameters @@ -720,7 +707,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes ctypes.byref(ret))) return ctypes2numpy(ret, length.value, np.float32) - def get_uint_info(self, field): + def get_uint_info(self, field: str) -> np.ndarray: """Get unsigned integer property from the DMatrix. Parameters @@ -741,7 +728,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes ctypes.byref(ret))) return ctypes2numpy(ret, length.value, np.uint32) - def set_float_info(self, field, data): + def set_float_info(self, field: str, data) -> None: """Set float type property into the DMatrix. Parameters @@ -755,7 +742,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes from .data import dispatch_meta_backend dispatch_meta_backend(self, data, field, 'float') - def set_float_info_npy2d(self, field, data): + def set_float_info_npy2d(self, field: str, data) -> None: """Set float type property into the DMatrix for numpy 2d array input @@ -770,7 +757,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes from .data import dispatch_meta_backend dispatch_meta_backend(self, data, field, 'float') - def set_uint_info(self, field, data): + def set_uint_info(self, field: str, data) -> None: """Set uint type property into the DMatrix. Parameters @@ -784,7 +771,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes from .data import dispatch_meta_backend dispatch_meta_backend(self, data, field, 'uint32') - def save_binary(self, fname, silent=True): + def save_binary(self, fname, silent=True) -> None: """Save DMatrix to an XGBoost buffer. Saved binary can be later loaded by providing the path to :py:func:`xgboost.DMatrix` as input. @@ -800,7 +787,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes c_str(fname), ctypes.c_int(silent))) - def set_label(self, label): + def set_label(self, label) -> None: """Set label of dmatrix Parameters @@ -811,7 +798,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes from .data import dispatch_meta_backend dispatch_meta_backend(self, label, 'label', 'float') - def set_weight(self, weight): + def set_weight(self, weight) -> None: """Set weight of each instance. Parameters @@ -830,7 +817,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes from .data import dispatch_meta_backend dispatch_meta_backend(self, weight, 'weight', 'float') - def set_base_margin(self, margin): + def set_base_margin(self, margin) -> None: """Set base margin of booster to start from. 
This can be used to specify a prediction value of existing model to be @@ -847,7 +834,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes from .data import dispatch_meta_backend dispatch_meta_backend(self, margin, 'base_margin', 'float') - def set_group(self, group): + def set_group(self, group) -> None: """Set group size of DMatrix (used for ranking). Parameters @@ -858,7 +845,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes from .data import dispatch_meta_backend dispatch_meta_backend(self, group, 'group', 'uint32') - def get_label(self): + def get_label(self) -> np.ndarray: """Get the label of the DMatrix. Returns @@ -867,7 +854,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes """ return self.get_float_info('label') - def get_weight(self): + def get_weight(self) -> np.ndarray: """Get the weight of the DMatrix. Returns @@ -876,7 +863,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes """ return self.get_float_info('weight') - def get_base_margin(self): + def get_base_margin(self) -> np.ndarray: """Get the base margin of the DMatrix. Returns @@ -885,7 +872,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes """ return self.get_float_info('base_margin') - def num_row(self): + def num_row(self) -> int: """Get the number of rows in the DMatrix. Returns @@ -897,7 +884,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes ctypes.byref(ret))) return ret.value - def num_col(self): + def num_col(self) -> int: """Get the number of columns (features) in the DMatrix. Returns @@ -1191,7 +1178,7 @@ class DeviceQuantileDMatrix(DMatrix): enable_categorical=enable_categorical, ) - def _init(self, data, enable_categorical, **meta): + def _init(self, data, enable_categorical: bool, **meta) -> None: from .data import ( _is_dlpack, _transform_dlpack, @@ -1265,7 +1252,7 @@ def _get_booster_layer_trees(model: "Booster") -> Tuple[int, int]: return num_parallel_tree, num_groups -class Booster(object): +class Booster: # pylint: disable=too-many-public-methods """A Booster of XGBoost. @@ -1273,7 +1260,12 @@ class Booster(object): training, prediction and evaluation. """ - def __init__(self, params=None, cache=(), model_file=None): + def __init__( + self, + params: Optional[Dict] = None, + cache: Optional[Sequence[DMatrix]] = None, + model_file: Optional[Union["Booster", bytearray, os.PathLike, str]] = None + ) -> None: # pylint: disable=invalid-name """ Parameters @@ -1285,12 +1277,13 @@ class Booster(object): model_file : string/os.PathLike/Booster/bytearray Path to the model file if it's string or PathLike. 
""" + cache = cache if cache is not None else [] for d in cache: if not isinstance(d, DMatrix): raise TypeError(f'invalid cache item: {type(d).__name__}', cache) dmats = c_array(ctypes.c_void_p, [d.handle for d in cache]) - self.handle = ctypes.c_void_p() + self.handle: Optional[ctypes.c_void_p] = ctypes.c_void_p() _check_call(_LIB.XGBoosterCreate(dmats, c_bst_ulong(len(cache)), ctypes.byref(self.handle))) for d in cache: @@ -1405,12 +1398,12 @@ class Booster(object): return params - def __del__(self): + def __del__(self) -> None: if hasattr(self, 'handle') and self.handle is not None: _check_call(_LIB.XGBoosterFree(self.handle)) self.handle = None - def __getstate__(self): + def __getstate__(self) -> Dict: # can't pickle ctypes pointers, put model content in bytearray this = self.__dict__.copy() handle = this['handle'] @@ -1424,7 +1417,7 @@ class Booster(object): this["handle"] = buf return this - def __setstate__(self, state): + def __setstate__(self, state: Dict) -> None: # reconstruct handle from raw data handle = state['handle'] if handle is not None: @@ -1440,7 +1433,7 @@ class Booster(object): state['handle'] = handle self.__dict__.update(state) - def __getitem__(self, val): + def __getitem__(self, val: Union[int, tuple, slice]) -> "Booster": if isinstance(val, int): val = slice(val, val+1) if isinstance(val, tuple): @@ -1461,13 +1454,14 @@ class Booster(object): step = val.step if val.step is not None else 1 - start = ctypes.c_int(start) - stop = ctypes.c_int(stop) - step = ctypes.c_int(step) + c_start = ctypes.c_int(start) + c_stop = ctypes.c_int(stop) + c_step = ctypes.c_int(step) sliced_handle = ctypes.c_void_p() - status = _LIB.XGBoosterSlice(self.handle, start, stop, step, - ctypes.byref(sliced_handle)) + status = _LIB.XGBoosterSlice( + self.handle, c_start, c_stop, c_step, ctypes.byref(sliced_handle) + ) if status == -2: raise IndexError('Layer index out of range') _check_call(status) @@ -1477,7 +1471,7 @@ class Booster(object): sliced.handle = sliced_handle return sliced - def save_config(self): + def save_config(self) -> str: '''Output internal parameter configuration of Booster as a JSON string. @@ -1489,10 +1483,11 @@ class Booster(object): self.handle, ctypes.byref(length), ctypes.byref(json_string))) - json_string = json_string.value.decode() # pylint: disable=no-member - return json_string + assert json_string.value is not None + result = json_string.value.decode() # pylint: disable=no-member + return result - def load_config(self, config): + def load_config(self, config: str) -> None: '''Load configuration returned by `save_config`. .. versionadded:: 1.0.0 @@ -1502,14 +1497,14 @@ class Booster(object): self.handle, c_str(config))) - def __copy__(self): + def __copy__(self) -> "Booster": return self.__deepcopy__(None) - def __deepcopy__(self, _): + def __deepcopy__(self, _) -> "Booster": '''Return a copy of booster.''' return Booster(model_file=self) - def copy(self): + def copy(self) -> "Booster": """Copy the booster object. Returns @@ -1519,7 +1514,7 @@ class Booster(object): """ return self.__copy__() - def attr(self, key): + def attr(self, key: str) -> Optional[str]: """Get attribute string from the Booster. Parameters @@ -1540,7 +1535,7 @@ class Booster(object): return py_str(ret.value) return None - def attributes(self): + def attributes(self) -> Dict[str, str]: """Get attributes stored in the Booster as a dictionary. 
Returns @@ -1572,7 +1567,7 @@ class Booster(object): _check_call(_LIB.XGBoosterSetAttr( self.handle, c_str(key), value)) - def _get_feature_info(self, field: str): + def _get_feature_info(self, field: str) -> Optional[List[str]]: length = c_bst_ulong() sarr = ctypes.POINTER(ctypes.c_char_p)() if not hasattr(self, "handle") or self.handle is None: @@ -1585,22 +1580,6 @@ class Booster(object): feature_info = from_cstr_to_pystr(sarr, length) return feature_info if feature_info else None - @property - def feature_types(self) -> Optional[List[str]]: - """Feature types for this booster. Can be directly set by input data or by - assignment. - - """ - return self._get_feature_info("feature_type") - - @property - def feature_names(self) -> Optional[List[str]]: - """Feature names for this booster. Can be directly set by input data or by - assignment. - - """ - return self._get_feature_info("feature_name") - def _set_feature_info(self, features: Optional[List[str]], field: str) -> None: if features is not None: assert isinstance(features, list) @@ -1618,14 +1597,30 @@ class Booster(object): ) ) - @feature_names.setter - def feature_names(self, features: Optional[List[str]]) -> None: - self._set_feature_info(features, "feature_name") + @property + def feature_types(self) -> Optional[List[str]]: + """Feature types for this booster. Can be directly set by input data or by + assignment. + + """ + return self._get_feature_info("feature_type") @feature_types.setter def feature_types(self, features: Optional[List[str]]) -> None: self._set_feature_info(features, "feature_type") + @property + def feature_names(self) -> Optional[List[str]]: + """Feature names for this booster. Can be directly set by input data or by + assignment. + + """ + return self._get_feature_info("feature_name") + + @feature_names.setter + def feature_names(self, features: Optional[List[str]]) -> None: + self._set_feature_info(features, "feature_name") + def set_param(self, params, value=None): """Set parameters into the Booster. @@ -1645,7 +1640,9 @@ class Booster(object): _check_call(_LIB.XGBoosterSetParam(self.handle, c_str(key), c_str(str(val)))) - def update(self, dtrain, iteration, fobj=None): + def update( + self, dtrain: DMatrix, iteration: int, fobj: Optional[Objective] = None + ) -> None: """Update for one iteration, with objective function calculated internally. This function should not be called directly by users. @@ -1672,18 +1669,18 @@ class Booster(object): grad, hess = fobj(pred, dtrain) self.boost(dtrain, grad, hess) - def boost(self, dtrain, grad, hess): + def boost(self, dtrain: DMatrix, grad: np.ndarray, hess: np.ndarray) -> None: """Boost the booster for one iteration, with customized gradient statistics. Like :py:func:`xgboost.Booster.update`, this function should not be called directly by users. Parameters ---------- - dtrain : DMatrix + dtrain : The training DMatrix. - grad : list + grad : The first order of gradient. - hess : list + hess : The second order of gradient. """ @@ -1700,17 +1697,23 @@ class Booster(object): c_array(ctypes.c_float, hess), c_bst_ulong(len(grad)))) - def eval_set(self, evals, iteration=0, feval=None, output_margin=True): + def eval_set( + self, + evals: Sequence[Tuple[DMatrix, str]], + iteration: int = 0, + feval: Optional[Metric] = None, + output_margin: bool = True + ) -> str: # pylint: disable=invalid-name """Evaluate a set of data. Parameters ---------- - evals : list of tuples (DMatrix, string) + evals : List of items to be evaluated. 
- iteration : int + iteration : Current iteration. - feval : function + feval : Custom evaluation function. Returns @@ -1738,6 +1741,7 @@ class Booster(object): ctypes.byref(msg), ) ) + assert msg.value is not None res = msg.value.decode() # pylint: disable=no-member if feval is not None: for dmat, evname in evals: @@ -1754,18 +1758,18 @@ class Booster(object): res += "\t%s-%s:%f" % (evname, name, val) return res - def eval(self, data, name='eval', iteration=0): + def eval(self, data: DMatrix, name: str = 'eval', iteration: int = 0) -> str: """Evaluate the model on mat. Parameters ---------- - data : DMatrix + data : The dmatrix storing the input. - name : str, optional + name : The name of the dataset. - iteration : int, optional + iteration : The current iteration number. Returns @@ -2101,7 +2105,7 @@ class Booster(object): "Data type:" + str(type(data)) + " not supported by inplace prediction." ) - def save_model(self, fname: Union[str, os.PathLike]): + def save_model(self, fname: Union[str, os.PathLike]) -> None: """Save the model to a file. The model is saved in an XGBoost internal format which is universal among the @@ -2124,7 +2128,7 @@ class Booster(object): else: raise TypeError("fname must be a string or os PathLike") - def save_raw(self): + def save_raw(self) -> bytearray: """Save the model to a in memory buffer representation instead of file. Returns @@ -2232,18 +2236,23 @@ class Booster(object): if need_close: fout.close() - def get_dump(self, fmap='', with_stats=False, dump_format="text"): + def get_dump( + self, + fmap: Union[str, os.PathLike] = "", + with_stats: bool = False, + dump_format: str = "text" + ) -> List[str]: """Returns the model dump as a list of strings. Unlike `save_model`, the output format is primarily used for visualization or interpretation, hence it's more human readable but cannot be loaded back to XGBoost. Parameters ---------- - fmap : string or os.PathLike, optional + fmap : Name of the file containing feature map names. - with_stats : bool, optional + with_stats : Controls whether the split statistics are output. - dump_format : string, optional + dump_format : Format of model dump. Can be 'text', 'json' or 'dot'. """ @@ -2259,7 +2268,9 @@ class Booster(object): res = from_cstr_to_pystr(sarr, length) return res - def get_fscore(self, fmap=''): + def get_fscore( + self, fmap: Union[str, os.PathLike] = "" + ) -> Dict[str, Union[float, List[float]]]: """Get feature importance of each feature. .. note:: Zero-importance features will not be included @@ -2269,7 +2280,7 @@ class Booster(object): Parameters ---------- - fmap: str or os.PathLike (optional) + fmap : The name of feature map file """ @@ -2299,9 +2310,9 @@ class Booster(object): Parameters ---------- - fmap: str or os.PathLike (optional) + fmap: The name of feature map file. - importance_type: str, default 'weight' + importance_type: One of the importance types defined above. Returns @@ -2343,7 +2354,8 @@ class Booster(object): results[feat] = float(score) return results - def trees_to_dataframe(self, fmap=''): # pylint: disable=too-many-statements + # pylint: disable=too-many-statements + def trees_to_dataframe(self, fmap: Union[str, os.PathLike] = '') -> DataFrame: """Parse a boosted tree model text dump into a pandas DataFrame structure. 
This feature is only defined when the decision tree model is chosen as base @@ -2370,7 +2382,7 @@ class Booster(object): node_ids = [] fids = [] splits = [] - categories = [] + categories: List[Optional[float]] = [] y_directs = [] n_directs = [] missings = [] @@ -2444,7 +2456,7 @@ class Booster(object): # pylint: disable=no-member return df.sort(['Tree', 'Node']).reset_index(drop=True) - def _validate_features(self, data: DMatrix): + def _validate_features(self, data: DMatrix) -> None: """ Validate Booster and data's feature_names are identical. Set feature_names and feature_types from DMatrix diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index e96a21a12..cb103b194 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -17,12 +17,13 @@ https://github.com/dask/dask-xgboost """ import platform import logging +import collections from contextlib import contextmanager from collections import defaultdict -from collections.abc import Sequence from threading import Thread from functools import partial, update_wrapper from typing import TYPE_CHECKING, List, Tuple, Callable, Optional, Any, Union, Dict, Set +from typing import Sequence from typing import Awaitable, Generator, TypeVar import numpy @@ -524,9 +525,9 @@ class DaskPartitionIter(DataIter): # pylint: disable=R0902 self._feature_names = feature_names self._feature_types = feature_types - assert isinstance(self._data, Sequence) + assert isinstance(self._data, collections.abc.Sequence) - types = (Sequence, type(None)) + types = (collections.abc.Sequence, type(None)) assert isinstance(self._labels, types) assert isinstance(self._weights, types) assert isinstance(self._base_margin, types) @@ -817,7 +818,7 @@ async def _get_rabit_args(n_workers: int, client: "distributed.Client") -> List[ def _get_workers_from_data( dtrain: DaskDMatrix, - evals: Optional[List[Tuple[DaskDMatrix, str]]] + evals: Optional[Sequence[Tuple[DaskDMatrix, str]]] ) -> List[str]: X_worker_map: Set[str] = set(dtrain.worker_map.keys()) if evals: @@ -837,13 +838,13 @@ async def _train_async( params: Dict[str, Any], dtrain: DaskDMatrix, num_boost_round: int, - evals: Optional[List[Tuple[DaskDMatrix, str]]], + evals: Optional[Sequence[Tuple[DaskDMatrix, str]]], obj: Optional[Objective], feval: Optional[Metric], early_stopping_rounds: Optional[int], verbose_eval: Union[int, bool], xgb_model: Optional[Booster], - callbacks: Optional[List[TrainingCallback]], + callbacks: Optional[Sequence[TrainingCallback]], custom_metric: Optional[Metric], ) -> Optional[TrainReturnT]: workers = _get_workers_from_data(dtrain, evals) @@ -951,13 +952,13 @@ def train( # pylint: disable=unused-argument dtrain: DaskDMatrix, num_boost_round: int = 10, *, - evals: Optional[List[Tuple[DaskDMatrix, str]]] = None, + evals: Optional[Sequence[Tuple[DaskDMatrix, str]]] = None, obj: Optional[Objective] = None, feval: Optional[Metric] = None, early_stopping_rounds: Optional[int] = None, xgb_model: Optional[Booster] = None, verbose_eval: Union[int, bool] = True, - callbacks: Optional[List[TrainingCallback]] = None, + callbacks: Optional[Sequence[TrainingCallback]] = None, custom_metric: Optional[Metric] = None, ) -> Any: """Train XGBoost model. 
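For context, the Sequence-typed evals and callbacks parameters above are consumed exactly as before. A hedged usage sketch of the dask training entry point, assuming a dask.distributed cluster is available; cluster size and data shapes here are illustrative.

import dask.array as da
import xgboost as xgb
from distributed import Client, LocalCluster

with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
    X = da.random.random((1000, 10), chunks=(100, 10))
    y = da.random.random(1000, chunks=100)
    dtrain = xgb.dask.DaskDMatrix(client, X, y)
    # evals is a sequence of (DaskDMatrix, name) pairs, mirroring the new
    # Optional[Sequence[Tuple[DaskDMatrix, str]]] annotation.
    output = xgb.dask.train(
        client,
        {"objective": "reg:squarederror", "tree_method": "hist"},
        dtrain,
        num_boost_round=10,
        evals=[(dtrain, "train")],
    )
    booster = output["booster"]  # trained Booster
    history = output["history"]  # nested dict of evaluation results
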
@@ -1648,15 +1649,15 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase): y: _DaskCollection, sample_weight: Optional[_DaskCollection], base_margin: Optional[_DaskCollection], - eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]], - eval_metric: Optional[Union[str, List[str], Metric]], - sample_weight_eval_set: Optional[List[_DaskCollection]], - base_margin_eval_set: Optional[List[_DaskCollection]], + eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]], + eval_metric: Optional[Union[str, Sequence[str], Metric]], + sample_weight_eval_set: Optional[Sequence[_DaskCollection]], + base_margin_eval_set: Optional[Sequence[_DaskCollection]], early_stopping_rounds: Optional[int], verbose: bool, xgb_model: Optional[Union[Booster, XGBModel]], feature_weights: Optional[_DaskCollection], - callbacks: Optional[List[TrainingCallback]], + callbacks: Optional[Sequence[TrainingCallback]], ) -> _DaskCollection: params = self.get_xgb_params() dtrain, evals = await _async_wrap_evaluation_matrices( @@ -1714,15 +1715,15 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase): *, sample_weight: Optional[_DaskCollection] = None, base_margin: Optional[_DaskCollection] = None, - eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None, - eval_metric: Optional[Union[str, List[str], Metric]] = None, + eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None, + eval_metric: Optional[Union[str, Sequence[str], Callable]] = None, early_stopping_rounds: Optional[int] = None, verbose: bool = True, xgb_model: Optional[Union[Booster, XGBModel]] = None, - sample_weight_eval_set: Optional[List[_DaskCollection]] = None, - base_margin_eval_set: Optional[List[_DaskCollection]] = None, + sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None, + base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None, feature_weights: Optional[_DaskCollection] = None, - callbacks: Optional[List[TrainingCallback]] = None, + callbacks: Optional[Sequence[TrainingCallback]] = None, ) -> "DaskXGBRegressor": _assert_dask_support() args = {k: v for k, v in locals().items() if k not in ("self", "__class__")} @@ -1738,15 +1739,15 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase): self, X: _DaskCollection, y: _DaskCollection, sample_weight: Optional[_DaskCollection], base_margin: Optional[_DaskCollection], - eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]], - eval_metric: Optional[Union[str, List[str], Metric]], - sample_weight_eval_set: Optional[List[_DaskCollection]], - base_margin_eval_set: Optional[List[_DaskCollection]], + eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]], + eval_metric: Optional[Union[str, Sequence[str], Metric]], + sample_weight_eval_set: Optional[Sequence[_DaskCollection]], + base_margin_eval_set: Optional[Sequence[_DaskCollection]], early_stopping_rounds: Optional[int], verbose: bool, xgb_model: Optional[Union[Booster, XGBModel]], feature_weights: Optional[_DaskCollection], - callbacks: Optional[List[TrainingCallback]] + callbacks: Optional[Sequence[TrainingCallback]] ) -> "DaskXGBClassifier": params = self.get_xgb_params() dtrain, evals = await _async_wrap_evaluation_matrices( @@ -1818,15 +1819,15 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase): *, sample_weight: Optional[_DaskCollection] = None, base_margin: Optional[_DaskCollection] = None, - eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None, - eval_metric: Optional[Union[str, 
List[str], Metric]] = None, + eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None, + eval_metric: Optional[Union[str, Sequence[str], Callable]] = None, early_stopping_rounds: Optional[int] = None, verbose: bool = True, xgb_model: Optional[Union[Booster, XGBModel]] = None, - sample_weight_eval_set: Optional[List[_DaskCollection]] = None, - base_margin_eval_set: Optional[List[_DaskCollection]] = None, + sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None, + base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None, feature_weights: Optional[_DaskCollection] = None, - callbacks: Optional[List[TrainingCallback]] = None + callbacks: Optional[Sequence[TrainingCallback]] = None ) -> "DaskXGBClassifier": _assert_dask_support() args = {k: v for k, v in locals().items() if k not in ("self", "__class__")} @@ -1935,17 +1936,17 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn): qid: Optional[_DaskCollection], sample_weight: Optional[_DaskCollection], base_margin: Optional[_DaskCollection], - eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]], - sample_weight_eval_set: Optional[List[_DaskCollection]], - base_margin_eval_set: Optional[List[_DaskCollection]], - eval_group: Optional[List[_DaskCollection]], - eval_qid: Optional[List[_DaskCollection]], - eval_metric: Optional[Union[str, List[str], Metric]], + eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]], + sample_weight_eval_set: Optional[Sequence[_DaskCollection]], + base_margin_eval_set: Optional[Sequence[_DaskCollection]], + eval_group: Optional[Sequence[_DaskCollection]], + eval_qid: Optional[Sequence[_DaskCollection]], + eval_metric: Optional[Union[str, Sequence[str], Metric]], early_stopping_rounds: Optional[int], verbose: bool, xgb_model: Optional[Union[XGBModel, Booster]], feature_weights: Optional[_DaskCollection], - callbacks: Optional[List[TrainingCallback]], + callbacks: Optional[Sequence[TrainingCallback]], ) -> "DaskXGBRanker": msg = "Use `qid` instead of `group` on dask interface." 
if not (group is None and eval_group is None): @@ -2010,17 +2011,17 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn): qid: Optional[_DaskCollection] = None, sample_weight: Optional[_DaskCollection] = None, base_margin: Optional[_DaskCollection] = None, - eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None, - eval_group: Optional[List[_DaskCollection]] = None, - eval_qid: Optional[List[_DaskCollection]] = None, - eval_metric: Optional[Union[str, List[str], Metric]] = None, + eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None, + eval_group: Optional[Sequence[_DaskCollection]] = None, + eval_qid: Optional[Sequence[_DaskCollection]] = None, + eval_metric: Optional[Union[str, Sequence[str], Callable]] = None, early_stopping_rounds: int = None, verbose: bool = False, xgb_model: Optional[Union[XGBModel, Booster]] = None, - sample_weight_eval_set: Optional[List[_DaskCollection]] = None, - base_margin_eval_set: Optional[List[_DaskCollection]] = None, + sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None, + base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None, feature_weights: Optional[_DaskCollection] = None, - callbacks: Optional[List[TrainingCallback]] = None + callbacks: Optional[Sequence[TrainingCallback]] = None ) -> "DaskXGBRanker": _assert_dask_support() args = {k: v for k, v in locals().items() if k not in ("self", "__class__")} @@ -2077,15 +2078,15 @@ class DaskXGBRFRegressor(DaskXGBRegressor): *, sample_weight: Optional[_DaskCollection] = None, base_margin: Optional[_DaskCollection] = None, - eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None, - eval_metric: Optional[Union[str, List[str], Metric]] = None, + eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None, + eval_metric: Optional[Union[str, Sequence[str], Callable]] = None, early_stopping_rounds: Optional[int] = None, verbose: bool = True, xgb_model: Optional[Union[Booster, XGBModel]] = None, - sample_weight_eval_set: Optional[List[_DaskCollection]] = None, - base_margin_eval_set: Optional[List[_DaskCollection]] = None, + sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None, + base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None, feature_weights: Optional[_DaskCollection] = None, - callbacks: Optional[List[TrainingCallback]] = None + callbacks: Optional[Sequence[TrainingCallback]] = None ) -> "DaskXGBRFRegressor": _assert_dask_support() args = {k: v for k, v in locals().items() if k not in ("self", "__class__")} @@ -2141,15 +2142,15 @@ class DaskXGBRFClassifier(DaskXGBClassifier): *, sample_weight: Optional[_DaskCollection] = None, base_margin: Optional[_DaskCollection] = None, - eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None, - eval_metric: Optional[Union[str, List[str], Metric]] = None, + eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None, + eval_metric: Optional[Union[str, Sequence[str], Callable]] = None, early_stopping_rounds: Optional[int] = None, verbose: bool = True, xgb_model: Optional[Union[Booster, XGBModel]] = None, - sample_weight_eval_set: Optional[List[_DaskCollection]] = None, - base_margin_eval_set: Optional[List[_DaskCollection]] = None, + sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None, + base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None, feature_weights: Optional[_DaskCollection] = None, - callbacks: Optional[List[TrainingCallback]] = None + callbacks: 
Optional[Sequence[TrainingCallback]] = None ) -> "DaskXGBRFClassifier": _assert_dask_support() args = {k: v for k, v in locals().items() if k not in ("self", "__class__")} diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index dd877cfc0..7313351fd 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -4,7 +4,8 @@ import copy import warnings import json import os -from typing import Union, Optional, List, Dict, Callable, Tuple, Any, TypeVar, Type +from typing import Union, Optional, List, Dict, Callable, Tuple, Any, TypeVar, Type, cast +from typing import Sequence import numpy as np from .core import Booster, DMatrix, XGBoostError @@ -36,7 +37,7 @@ class XGBRankerMixIn: # pylint: disable=too-few-public-methods def _check_rf_callback( early_stopping_rounds: Optional[int], - callbacks: Optional[List[TrainingCallback]], + callbacks: Optional[Sequence[TrainingCallback]], ) -> None: if early_stopping_rounds is not None or callbacks is not None: raise NotImplementedError( @@ -343,14 +344,14 @@ def _wrap_evaluation_matrices( sample_weight: Optional[Any], base_margin: Optional[Any], feature_weights: Optional[Any], - eval_set: Optional[List[Tuple[Any, Any]]], - sample_weight_eval_set: Optional[List[Any]], - base_margin_eval_set: Optional[List[Any]], - eval_group: Optional[List[Any]], - eval_qid: Optional[List[Any]], + eval_set: Optional[Sequence[Tuple[Any, Any]]], + sample_weight_eval_set: Optional[Sequence[Any]], + base_margin_eval_set: Optional[Sequence[Any]], + eval_group: Optional[Sequence[Any]], + eval_qid: Optional[Sequence[Any]], create_dmatrix: Callable, enable_categorical: bool, -) -> Tuple[Any, Optional[List[Tuple[Any, str]]]]: +) -> Tuple[Any, List[Tuple[Any, str]]]: """Convert array_like evaluation matrices into DMatrix. Perform validation on the way. 
""" @@ -368,7 +369,7 @@ def _wrap_evaluation_matrices( n_validation = 0 if eval_set is None else len(eval_set) - def validate_or_none(meta: Optional[List], name: str) -> List: + def validate_or_none(meta: Optional[Sequence], name: str) -> Sequence: if meta is None: return [None] * n_validation if len(meta) != n_validation: @@ -464,7 +465,7 @@ class XGBModel(XGBModelBase): missing: float = np.nan, num_parallel_tree: Optional[int] = None, monotone_constraints: Optional[Union[Dict[str, int], str]] = None, - interaction_constraints: Optional[Union[str, List[Tuple[str]]]] = None, + interaction_constraints: Optional[Union[str, Sequence[Sequence[str]]]] = None, importance_type: Optional[str] = None, gpu_id: Optional[int] = None, validate_parameters: Optional[bool] = None, @@ -715,7 +716,7 @@ class XGBModel(XGBModelBase): def _configure_fit( self, booster: Optional[Union[Booster, "XGBModel", str]], - eval_metric: Optional[Union[Callable, str, List[str]]], + eval_metric: Optional[Union[Callable, str, Sequence[str]]], params: Dict[str, Any], early_stopping_rounds: Optional[int], ) -> Tuple[ @@ -788,10 +789,7 @@ class XGBModel(XGBModelBase): def _set_evaluation_result(self, evals_result: TrainingCallback.EvalsLog) -> None: if evals_result: - for val in evals_result.items(): - evals_result_key = list(val[1].keys())[0] - evals_result[val[0]][evals_result_key] = val[1][evals_result_key] - self.evals_result_ = evals_result + self.evals_result_ = cast(Dict[str, Dict[str, List[float]]], evals_result) @_deprecate_positional_args def fit( @@ -801,15 +799,15 @@ class XGBModel(XGBModelBase): *, sample_weight: Optional[array_like] = None, base_margin: Optional[array_like] = None, - eval_set: Optional[List[Tuple[array_like, array_like]]] = None, - eval_metric: Optional[Union[str, List[str], Metric]] = None, + eval_set: Optional[Sequence[Tuple[array_like, array_like]]] = None, + eval_metric: Optional[Union[str, Sequence[str], Metric]] = None, early_stopping_rounds: Optional[int] = None, verbose: Optional[bool] = True, xgb_model: Optional[Union[Booster, str, "XGBModel"]] = None, - sample_weight_eval_set: Optional[List[array_like]] = None, - base_margin_eval_set: Optional[List[array_like]] = None, + sample_weight_eval_set: Optional[Sequence[array_like]] = None, + base_margin_eval_set: Optional[Sequence[array_like]] = None, feature_weights: Optional[array_like] = None, - callbacks: Optional[List[TrainingCallback]] = None + callbacks: Optional[Sequence[TrainingCallback]] = None ) -> "XGBModel": # pylint: disable=invalid-name,attribute-defined-outside-init """Fit gradient boosting model. @@ -1031,7 +1029,7 @@ class XGBModel(XGBModelBase): Input features matrix. iteration_range : - See :py:meth:`xgboost.XGBRegressor.predict`. + See :py:meth:`predict`. ntree_limit : Deprecated, use ``iteration_range`` instead. @@ -1055,40 +1053,26 @@ class XGBModel(XGBModelBase): iteration_range=iteration_range ) - def evals_result(self) -> TrainingCallback.EvalsLog: + def evals_result(self) -> Dict[str, Dict[str, List[float]]]: """Return the evaluation results. - If **eval_set** is passed to the `fit` function, you can call - ``evals_result()`` to get evaluation results for all passed **eval_sets**. - When **eval_metric** is also passed to the `fit` function, the - **evals_result** will contain the **eval_metrics** passed to the `fit` function. + If **eval_set** is passed to the :py:meth:`fit` function, you can call + ``evals_result()`` to get evaluation results for all passed **eval_sets**. 
When + **eval_metric** is also passed to the :py:meth:`fit` function, the + **evals_result** will contain the **eval_metrics** passed to the :py:meth:`fit` + function. - Returns - ------- - evals_result : dictionary - - Example - ------- - - .. code-block:: python - - param_dist = {'objective':'binary:logistic', 'n_estimators':2} - - clf = xgb.XGBModel(**param_dist) - - clf.fit(X_train, y_train, - eval_set=[(X_train, y_train), (X_test, y_test)], - eval_metric='logloss', - verbose=True) - - evals_result = clf.evals_result() - - The variable **evals_result** will contain: + The returned evaluation result is a dictionary: .. code-block:: python {'validation_0': {'logloss': ['0.604835', '0.531479']}, 'validation_1': {'logloss': ['0.41965', '0.17686']}} + + Returns + ------- + evals_result + """ if getattr(self, "evals_result_", None) is not None: evals_result = self.evals_result_ @@ -1193,8 +1177,8 @@ class XGBModel(XGBModelBase): .. note:: Intercept is defined only for linear learners Intercept (bias) is only defined when the linear model is chosen as base - learner (`booster=gblinear`). It is not defined for other base learner types, such - as tree learners (`booster=gbtree`). + learner (`booster=gblinear`). It is not defined for other base learner types, + such as tree learners (`booster=gbtree`). Returns ------- @@ -1251,15 +1235,15 @@ class XGBClassifier(XGBModel, XGBClassifierBase): *, sample_weight: Optional[array_like] = None, base_margin: Optional[array_like] = None, - eval_set: Optional[List[Tuple[array_like, array_like]]] = None, - eval_metric: Optional[Union[str, List[str], Metric]] = None, + eval_set: Optional[Sequence[Tuple[array_like, array_like]]] = None, + eval_metric: Optional[Union[str, Sequence[str], Metric]] = None, early_stopping_rounds: Optional[int] = None, verbose: Optional[bool] = True, xgb_model: Optional[Union[Booster, str, XGBModel]] = None, - sample_weight_eval_set: Optional[List[array_like]] = None, - base_margin_eval_set: Optional[List[array_like]] = None, + sample_weight_eval_set: Optional[Sequence[array_like]] = None, + base_margin_eval_set: Optional[Sequence[array_like]] = None, feature_weights: Optional[array_like] = None, - callbacks: Optional[List[TrainingCallback]] = None + callbacks: Optional[Sequence[TrainingCallback]] = None ) -> "XGBClassifier": # pylint: disable = attribute-defined-outside-init,too-many-statements evals_result: TrainingCallback.EvalsLog = {} @@ -1445,51 +1429,6 @@ class XGBClassifier(XGBModel, XGBClassifierBase): getattr(self, "n_classes_", None), class_probs, np.vstack ) - def evals_result(self) -> TrainingCallback.EvalsLog: - """Return the evaluation results. - - If **eval_set** is passed to the `fit` function, you can call - ``evals_result()`` to get evaluation results for all passed **eval_sets**. - - When **eval_metric** is also passed as a parameter, the **evals_result** will - contain the **eval_metric** passed to the `fit` function. - - Returns - ------- - evals_result : dictionary - - Example - ------- - - .. code-block:: python - - param_dist = { - 'objective':'binary:logistic', 'n_estimators':2, eval_metric="logloss" - } - - clf = xgb.XGBClassifier(**param_dist) - - clf.fit(X_train, y_train, - eval_set=[(X_train, y_train), (X_test, y_test)], - verbose=True) - - evals_result = clf.evals_result() - - The variable **evals_result** will contain - - .. 
code-block:: python - - {'validation_0': {'logloss': ['0.604835', '0.531479']}, - 'validation_1': {'logloss': ['0.41965', '0.17686']}} - - """ - if self.evals_result_: - evals_result = self.evals_result_ - else: - raise XGBoostError('No results.') - - return evals_result - @xgboost_model_doc( "scikit-learn API for XGBoost random forest classification.", @@ -1533,15 +1472,15 @@ class XGBRFClassifier(XGBClassifier): *, sample_weight: Optional[array_like] = None, base_margin: Optional[array_like] = None, - eval_set: Optional[List[Tuple[array_like, array_like]]] = None, - eval_metric: Optional[Union[str, List[str], Metric]] = None, + eval_set: Optional[Sequence[Tuple[array_like, array_like]]] = None, + eval_metric: Optional[Union[str, Sequence[str], Metric]] = None, early_stopping_rounds: Optional[int] = None, verbose: Optional[bool] = True, xgb_model: Optional[Union[Booster, str, XGBModel]] = None, - sample_weight_eval_set: Optional[List[array_like]] = None, - base_margin_eval_set: Optional[List[array_like]] = None, + sample_weight_eval_set: Optional[Sequence[array_like]] = None, + base_margin_eval_set: Optional[Sequence[array_like]] = None, feature_weights: Optional[array_like] = None, - callbacks: Optional[List[TrainingCallback]] = None + callbacks: Optional[Sequence[TrainingCallback]] = None ) -> "XGBRFClassifier": args = {k: v for k, v in locals().items() if k not in ("self", "__class__")} _check_rf_callback(early_stopping_rounds, callbacks) @@ -1605,15 +1544,15 @@ class XGBRFRegressor(XGBRegressor): *, sample_weight: Optional[array_like] = None, base_margin: Optional[array_like] = None, - eval_set: Optional[List[Tuple[array_like, array_like]]] = None, - eval_metric: Optional[Union[str, List[str], Metric]] = None, + eval_set: Optional[Sequence[Tuple[array_like, array_like]]] = None, + eval_metric: Optional[Union[str, Sequence[str], Metric]] = None, early_stopping_rounds: Optional[int] = None, verbose: Optional[bool] = True, xgb_model: Optional[Union[Booster, str, XGBModel]] = None, - sample_weight_eval_set: Optional[List[array_like]] = None, - base_margin_eval_set: Optional[List[array_like]] = None, + sample_weight_eval_set: Optional[Sequence[array_like]] = None, + base_margin_eval_set: Optional[Sequence[array_like]] = None, feature_weights: Optional[array_like] = None, - callbacks: Optional[List[TrainingCallback]] = None + callbacks: Optional[Sequence[TrainingCallback]] = None ) -> "XGBRFRegressor": args = {k: v for k, v in locals().items() if k not in ("self", "__class__")} _check_rf_callback(early_stopping_rounds, callbacks) @@ -1682,17 +1621,17 @@ class XGBRanker(XGBModel, XGBRankerMixIn): qid: Optional[array_like] = None, sample_weight: Optional[array_like] = None, base_margin: Optional[array_like] = None, - eval_set: Optional[List[Tuple[array_like, array_like]]] = None, - eval_group: Optional[List[array_like]] = None, - eval_qid: Optional[List[array_like]] = None, - eval_metric: Optional[Union[str, List[str], Metric]] = None, + eval_set: Optional[Sequence[Tuple[array_like, array_like]]] = None, + eval_group: Optional[Sequence[array_like]] = None, + eval_qid: Optional[Sequence[array_like]] = None, + eval_metric: Optional[Union[str, Sequence[str], Metric]] = None, early_stopping_rounds: Optional[int] = None, verbose: Optional[bool] = False, xgb_model: Optional[Union[Booster, str, XGBModel]] = None, - sample_weight_eval_set: Optional[List[array_like]] = None, - base_margin_eval_set: Optional[List[array_like]] = None, + sample_weight_eval_set: Optional[Sequence[array_like]] = None, + 
base_margin_eval_set: Optional[Sequence[array_like]] = None, feature_weights: Optional[array_like] = None, - callbacks: Optional[List[TrainingCallback]] = None + callbacks: Optional[Sequence[TrainingCallback]] = None ) -> "XGBRanker": # pylint: disable = attribute-defined-outside-init,arguments-differ """Fit gradient boosting ranker diff --git a/python-package/xgboost/tracker.py b/python-package/xgboost/tracker.py index 61e3a1a06..06051c574 100644 --- a/python-package/xgboost/tracker.py +++ b/python-package/xgboost/tracker.py @@ -13,7 +13,7 @@ from threading import Thread import argparse import sys -from typing import Dict, List, Tuple, Union, Optional +from typing import Dict, List, Tuple, Union, Optional _RingMap = Dict[int, Tuple[int, int]] _TreeMap = Dict[int, List[int]] diff --git a/python-package/xgboost/training.py b/python-package/xgboost/training.py index 2b0035a9a..9127b8e6f 100644 --- a/python-package/xgboost/training.py +++ b/python-package/xgboost/training.py @@ -3,18 +3,20 @@ # pylint: disable=too-many-branches, too-many-statements """Training Library containing training routines.""" import copy -from typing import Optional, List +import os import warnings +from typing import Optional, Dict, Any, Union, Tuple, cast, Sequence import numpy as np -from .core import Booster, XGBoostError, _get_booster_layer_trees -from .core import _deprecate_positional_args -from .core import Objective, Metric +from .core import Booster, DMatrix, XGBoostError, _get_booster_layer_trees +from .core import Metric, Objective from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold) from . import callback -def _assert_new_callback(callbacks: Optional[List[callback.TrainingCallback]]) -> None: +def _assert_new_callback( + callbacks: Optional[Sequence[callback.TrainingCallback]] +) -> None: is_new_callback: bool = not callbacks or all( isinstance(c, callback.TrainingCallback) for c in callbacks ) @@ -44,24 +46,24 @@ def _configure_custom_metric( def _train_internal( - params, - dtrain, - num_boost_round=10, - evals=(), - obj=None, - feval=None, - custom_metric=None, - xgb_model=None, - callbacks=None, - evals_result=None, - maximize=None, - verbose_eval=None, - early_stopping_rounds=None, -): + params: Dict[str, Any], + dtrain: DMatrix, + num_boost_round: int = 10, + evals: Optional[Sequence[Tuple[DMatrix, str]]] = None, + obj: Optional[Objective] = None, + feval: Optional[Metric] = None, + custom_metric: Optional[Metric] = None, + xgb_model: Optional[Union[str, os.PathLike, Booster, bytearray]] = None, + callbacks: Optional[Sequence[callback.TrainingCallback]] = None, + evals_result: callback.TrainingCallback.EvalsLog = None, + maximize: Optional[bool] = None, + verbose_eval: Optional[Union[bool, int]] = True, + early_stopping_rounds: Optional[int] = None, +) -> Booster: """internal training function""" - callbacks = [] if callbacks is None else copy.copy(callbacks) + callbacks = [] if callbacks is None else copy.copy(list(callbacks)) metric_fn = _configure_custom_metric(feval, custom_metric) - evals = list(evals) + evals = list(evals) if evals else [] bst = Booster(params, [dtrain] + [d[0] for d in evals]) @@ -78,7 +80,7 @@ def _train_internal( callbacks.append( callback.EarlyStopping(rounds=early_stopping_rounds, maximize=maximize) ) - callbacks = callback.CallbackContainer( + cb_container = callback.CallbackContainer( callbacks, metric=metric_fn, # For old `feval` parameter, the behavior is unchanged. 
@@ -87,32 +89,32 @@ def _train_internal(
         output_margin=callable(obj) or metric_fn is feval,
     )
 
-    bst = callbacks.before_training(bst)
+    bst = cb_container.before_training(bst)
 
     for i in range(start_iteration, num_boost_round):
-        if callbacks.before_iteration(bst, i, dtrain, evals):
+        if cb_container.before_iteration(bst, i, dtrain, evals):
             break
         bst.update(dtrain, i, obj)
-        if callbacks.after_iteration(bst, i, dtrain, evals):
+        if cb_container.after_iteration(bst, i, dtrain, evals):
             break
 
-    bst = callbacks.after_training(bst)
+    bst = cb_container.after_training(bst)
 
     if evals_result is not None:
-        evals_result.update(callbacks.history)
+        evals_result.update(cb_container.history)
 
     # These should be moved into callback functions `after_training`, but until old
     # callbacks are removed, the train function is the only place for setting the
     # attributes.
     num_parallel_tree, _ = _get_booster_layer_trees(bst)
     if bst.attr('best_score') is not None:
-        bst.best_score = float(bst.attr('best_score'))
-        bst.best_iteration = int(bst.attr('best_iteration'))
+        bst.best_score = float(cast(str, bst.attr('best_score')))
+        bst.best_iteration = int(cast(str, bst.attr('best_iteration')))
         # num_class is handled internally
         bst.set_attr(
             best_ntree_limit=str((bst.best_iteration + 1) * num_parallel_tree)
         )
-        bst.best_ntree_limit = int(bst.attr("best_ntree_limit"))
+        bst.best_ntree_limit = int(cast(str, bst.attr("best_ntree_limit")))
     else:
         # Due to compatibility with version older than 1.4, these attributes are added
         # to Python object even if early stopping is not used.
@@ -126,35 +128,32 @@ def _train_internal(
     return bst.copy()
 
 
-@_deprecate_positional_args
 def train(
-    params,
-    dtrain,
-    num_boost_round=10,
-    *,
-    evals=(),
+    params: Dict[str, Any],
+    dtrain: DMatrix,
+    num_boost_round: int = 10,
+    evals: Optional[Sequence[Tuple[DMatrix, str]]] = None,
     obj: Optional[Objective] = None,
-    feval=None,
-    maximize=None,
-    early_stopping_rounds=None,
-    evals_result=None,
-    verbose_eval=True,
-    xgb_model=None,
-    callbacks=None,
+    feval: Optional[Metric] = None,
+    maximize: Optional[bool] = None,
+    early_stopping_rounds: Optional[int] = None,
+    evals_result: callback.TrainingCallback.EvalsLog = None,
+    verbose_eval: Optional[Union[bool, int]] = True,
+    xgb_model: Optional[Union[str, os.PathLike, Booster, bytearray]] = None,
+    callbacks: Optional[Sequence[callback.TrainingCallback]] = None,
     custom_metric: Optional[Metric] = None,
-):
-    # pylint: disable=too-many-statements,too-many-branches, attribute-defined-outside-init
+) -> Booster:
     """Train a booster with given parameters.
 
     Parameters
     ----------
-    params : dict
+    params :
         Booster params.
-    dtrain : DMatrix
+    dtrain :
         Data to be trained.
-    num_boost_round: int
+    num_boost_round :
         Number of boosting iterations.
-    evals: list of pairs (DMatrix, string)
+    evals :
         List of validation sets for which metrics will be evaluated during training.
         Validation metrics will help us track the performance of the model.
     obj
@@ -166,7 +165,7 @@ def train(
         Use `custom_metric` instead.
     maximize : bool
         Whether to maximize feval.
-    early_stopping_rounds: int
+    early_stopping_rounds :
         Activates early stopping. Validation metric needs to improve at least once in
         every **early_stopping_rounds** round(s) to continue training.
         Requires at least one item in **evals**.
@@ -178,7 +177,7 @@ def train(
         **params**, the last metric will be used for early stopping.
         If early stopping occurs, the model will have two additional fields:
         ``bst.best_score``, ``bst.best_iteration``.
-    evals_result: dict
+    evals_result :
         This dictionary stores the evaluation results of all the items in watchlist.
 
         Example: with a watchlist containing
@@ -191,7 +190,7 @@ def train(
 
             {'train': {'logloss': ['0.48253', '0.35953']},
              'eval': {'logloss': ['0.480385', '0.357756']}}
 
-    verbose_eval : bool or int
+    verbose_eval :
         Requires at least one item in **evals**.
         If **verbose_eval** is True then the evaluation metric on the validation set is
         printed at each boosting stage.
@@ -200,9 +199,9 @@ def train(
         / the boosting stage found by using **early_stopping_rounds** is also printed.
         Example: with ``verbose_eval=4`` and at least one item in **evals**, an
         evaluation metric is printed every 4 boosting stages, instead of every boosting
         stage.
-    xgb_model : file name of stored xgb model or 'Booster' instance
+    xgb_model :
         Xgb model to be loaded before training (allows training continuation).
-    callbacks : list of callback functions
+    callbacks :
         List of callback functions that are applied at the end of each iteration.
         It is possible to use predefined callbacks by using
         :ref:`Callback API `.
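A usage sketch consistent with the docstring above, showing how ``evals``, ``early_stopping_rounds``, ``evals_result`` and ``verbose_eval`` interact; the data, parameter values and dataset names are made up for illustration:

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    X = np.random.randn(256, 8)
    y = np.random.randint(0, 2, size=256)
    dtrain = xgb.DMatrix(X[:200], label=y[:200])
    dvalid = xgb.DMatrix(X[200:], label=y[200:])

    evals_result: dict = {}
    bst = xgb.train(
        {"objective": "binary:logistic", "eval_metric": "logloss"},
        dtrain,
        num_boost_round=100,
        evals=[(dtrain, "train"), (dvalid, "eval")],
        early_stopping_rounds=5,    # stop when eval-logloss stops improving
        evals_result=evals_result,  # filled in place, keyed by dataset then metric
        verbose_eval=10,            # print the metrics every 10 boosting stages
    )

    # Set when early stopping is active, as described above.
    print(bst.best_iteration, bst.best_score)
    print(evals_result["eval"]["logloss"][-1])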
diff --git a/tests/python/test_interaction_constraints.py b/tests/python/test_interaction_constraints.py
index c582614c8..12653b538 100644
--- a/tests/python/test_interaction_constraints.py
+++ b/tests/python/test_interaction_constraints.py
@@ -9,7 +9,9 @@ rng = np.random.RandomState(1994)
 
 class TestInteractionConstraints:
-    def run_interaction_constraints(self, tree_method, feature_names=None, interaction_constraints='[[0, 1]]'):
+    def run_interaction_constraints(
+        self, tree_method, feature_names=None, interaction_constraints='[[0, 1]]'
+    ):
         x1 = np.random.normal(loc=1.0, scale=1.0, size=1000)
         x2 = np.random.normal(loc=1.0, scale=1.0, size=1000)
         x3 = np.random.choice([1, 2, 3], size=1000, replace=True)