Typehint for subset of core API. (#7348)

Jiaming Yuan 2021-10-28 20:47:04 +08:00 committed by GitHub
parent 45aef75cca
commit c6769488b3
7 changed files with 400 additions and 412 deletions

File: python-package/xgboost/callback.py

@@ -6,17 +6,18 @@ from abc import ABC
 import collections
 import os
 import pickle
-from typing import Callable, List, Optional, Union, Dict, Tuple
+from typing import Callable, List, Optional, Union, Dict, Tuple, TypeVar, cast
+from typing import Sequence
 import numpy
 from . import rabit
-from .core import Booster, XGBoostError
+from .core import Booster, DMatrix, XGBoostError
 from .compat import STRING_TYPES
-# The new implementation of callback functions.
-# Breaking:
-# - reset learning rate no longer accepts total boosting rounds
+_Score = Union[float, Tuple[float, float]]
+_ScoreList = Union[List[float], List[Tuple[float, float]]]
 # pylint: disable=unused-argument
 class TrainingCallback(ABC):

@@ -26,9 +27,9 @@ class TrainingCallback(ABC):
     '''
-    EvalsLog = Dict[str, Dict[str, Union[List[float], List[Tuple[float, float]]]]]
+    EvalsLog = Dict[str, Dict[str, _ScoreList]]
-    def __init__(self):
+    def __init__(self) -> None:
         pass
     def before_training(self, model):

@@ -48,18 +49,18 @@ class TrainingCallback(ABC):
         return False
-def _aggcv(rlist):
-    # pylint: disable=invalid-name
+def _aggcv(rlist: List[str]) -> List[Tuple[str, float, float]]:
+    # pylint: disable=invalid-name, too-many-locals
     """Aggregate cross-validation results.
    """
-    cvmap = {}
+    cvmap: Dict[Tuple[int, str], List[float]] = {}
     idx = rlist[0].split()[0]
     for line in rlist:
-        arr = line.split()
+        arr: List[str] = line.split()
         assert idx == arr[0]
         for metric_idx, it in enumerate(arr[1:]):
-            if not isinstance(it, STRING_TYPES):
+            if not isinstance(it, str):
                 it = it.decode()
             k, v = it.split(':')
             if (metric_idx, k) not in cvmap:

@@ -67,16 +68,20 @@ def _aggcv(rlist):
             cvmap[(metric_idx, k)].append(float(v))
     msg = idx
     results = []
-    for (metric_idx, k), v in sorted(cvmap.items(), key=lambda x: x[0][0]):
-        v = numpy.array(v)
+    for (_, name), s in sorted(cvmap.items(), key=lambda x: x[0][0]):
+        as_arr = numpy.array(s)
         if not isinstance(msg, STRING_TYPES):
             msg = msg.decode()
-        mean, std = numpy.mean(v), numpy.std(v)
-        results.extend([(k, mean, std)])
+        mean, std = numpy.mean(as_arr), numpy.std(as_arr)
+        results.extend([(name, mean, std)])
     return results
-def _allreduce_metric(score):
+# allreduce type
+_ART = TypeVar("_ART")
+def _allreduce_metric(score: _ART) -> _ART:
     '''Helper function for computing customized metric in distributed
     environment. Not strictly correct as many functions don't use mean value
     as final result.

@@ -89,13 +94,13 @@ def _allreduce_metric(score):
     if isinstance(score, tuple):  # has mean and stdv
         raise ValueError(
             'xgboost.cv function should not be used in distributed environment.')
-    score = numpy.array([score])
-    score = rabit.allreduce(score, rabit.Op.SUM) / world
-    return score[0]
+    arr = numpy.array([score])
+    arr = rabit.allreduce(arr, rabit.Op.SUM) / world
+    return arr[0]
 class CallbackContainer:
-    '''A special callback for invoking a list of other callbacks.
+    '''A special internal callback for invoking a list of other callbacks.
     .. versionadded:: 1.3.0

@@ -105,7 +110,7 @@ class CallbackContainer:
     def __init__(
         self,
-        callbacks: List[TrainingCallback],
+        callbacks: Sequence[TrainingCallback],
         metric: Callable = None,
         output_margin: bool = True,
         is_cv: bool = False

@@ -146,33 +151,50 @@ class CallbackContainer:
         assert isinstance(model, Booster), msg
         return model
-    def before_iteration(self, model, epoch, dtrain, evals) -> bool:
+    def before_iteration(
+        self, model, epoch: int, dtrain: DMatrix, evals: List[Tuple[DMatrix, str]]
+    ) -> bool:
         '''Function called before training iteration.'''
         return any(c.before_iteration(model, epoch, self.history)
                    for c in self.callbacks)
-    def _update_history(self, score, epoch):
+    def _update_history(
+        self,
+        score: Union[List[Tuple[str, float]], List[Tuple[str, float, float]]],
+        epoch: int
+    ) -> None:
         for d in score:
-            name, s = d[0], float(d[1])
+            name: str = d[0]
+            s: float = d[1]
             if self.is_cv:
-                std = float(d[2])
-                s = (s, std)
+                std = float(cast(Tuple[str, float, float], d)[2])
+                x: _Score = (s, std)
+            else:
+                x = s
             splited_names = name.split('-')
            data_name = splited_names[0]
             metric_name = '-'.join(splited_names[1:])
-            s = _allreduce_metric(s)
-            if data_name in self.history:
-                data_history = self.history[data_name]
-                if metric_name in data_history:
-                    data_history[metric_name].append(s)
-                else:
-                    data_history[metric_name] = [s]
-            else:
+            x = _allreduce_metric(x)
+            if data_name not in self.history:
                 self.history[data_name] = collections.OrderedDict()
-                self.history[data_name][metric_name] = [s]
-        return False
+            data_history = self.history[data_name]
+            if metric_name not in data_history:
+                data_history[metric_name] = cast(_ScoreList, [])
+            metric_history = data_history[metric_name]
+            if self.is_cv:
+                cast(List[Tuple[float, float]], metric_history).append(
+                    cast(Tuple[float, float], x)
+                )
+            else:
+                cast(List[float], metric_history).append(cast(float, x))
-    def after_iteration(self, model, epoch, dtrain, evals) -> bool:
+    def after_iteration(
+        self,
+        model,
+        epoch: int,
+        dtrain: DMatrix,
+        evals: Optional[List[Tuple[DMatrix, str]]],
+    ) -> bool:
         '''Function called after training iteration.'''
         if self.is_cv:
             scores = model.eval(epoch, self.metric, self._output_margin)

@@ -183,18 +205,20 @@ class CallbackContainer:
             evals = [] if evals is None else evals
             for _, name in evals:
                 assert name.find('-') == -1, 'Dataset name should not contain `-`'
-            score = model.eval_set(evals, epoch, self.metric, self._output_margin)
-            score = score.split()[1:]  # into datasets
+            score: str = model.eval_set(evals, epoch, self.metric, self._output_margin)
+            splited = score.split()[1:]  # into datasets
             # split up `test-error:0.1234`
-            score = [tuple(s.split(':')) for s in score]
-            self._update_history(score, epoch)
+            metric_score_str = [tuple(s.split(':')) for s in splited]
+            # convert to float
+            metric_score = [(n, float(s)) for n, s in metric_score_str]
+            self._update_history(metric_score, epoch)
         ret = any(c.after_iteration(model, epoch, self.history)
                   for c in self.callbacks)
         return ret
 class LearningRateScheduler(TrainingCallback):
-    '''Callback function for scheduling learning rate.
+    """Callback function for scheduling learning rate.
     .. versionadded:: 1.3.0

@@ -207,18 +231,24 @@ class LearningRateScheduler(TrainingCallback):
         should be a sequence like list or tuple with the same size of boosting
         rounds.
-    '''
-    def __init__(self, learning_rates) -> None:
-        assert callable(learning_rates) or \
-            isinstance(learning_rates, collections.abc.Sequence)
+    """
+    def __init__(
+        self, learning_rates: Union[Callable[[int], float], Sequence[float]]
+    ) -> None:
+        assert callable(learning_rates) or isinstance(
+            learning_rates, collections.abc.Sequence
+        )
         if callable(learning_rates):
             self.learning_rates = learning_rates
         else:
-            self.learning_rates = lambda epoch: learning_rates[epoch]
+            self.learning_rates = lambda epoch: cast(Sequence, learning_rates)[epoch]
         super().__init__()
-    def after_iteration(self, model, epoch, evals_log) -> bool:
-        model.set_param('learning_rate', self.learning_rates(epoch))
+    def after_iteration(
+        self, model, epoch: int, evals_log: TrainingCallback.EvalsLog
+    ) -> bool:
+        model.set_param("learning_rate", self.learning_rates(epoch))
         return False

@@ -230,17 +260,17 @@ class EarlyStopping(TrainingCallback):
     Parameters
     ----------
-    rounds
+    rounds :
         Early stopping rounds.
-    metric_name
+    metric_name :
         Name of metric that is used for early stopping.
-    data_name
+    data_name :
         Name of dataset that is used for early stopping.
-    maximize
+    maximize :
         Whether to maximize evaluation metric. None means auto (discouraged).
-    save_best
+    save_best :
         Whether training should return the best model or the last model.
-    min_delta
+    min_delta :
         Minimum absolute change in score to be qualified as an improvement.
         .. versionadded:: 1.5.0

@@ -279,8 +309,6 @@ class EarlyStopping(TrainingCallback):
         if self._min_delta < 0:
             raise ValueError("min_delta must be greater or equal to 0.")
-        self.improve_op = None
         self.current_rounds: int = 0
         self.best_scores: dict = {}
         self.starting_round: int = 0

@@ -290,16 +318,18 @@ class EarlyStopping(TrainingCallback):
         self.starting_round = model.num_boosted_rounds()
         return model
-    def _update_rounds(self, score, name, metric, model, epoch) -> bool:
-        def get_s(x):
+    def _update_rounds(
+        self, score: _Score, name: str, metric: str, model, epoch: int
+    ) -> bool:
+        def get_s(x: _Score) -> float:
             """get score if it's cross validation history."""
             return x[0] if isinstance(x, tuple) else x
-        def maximize(new, best):
+        def maximize(new: _Score, best: _Score) -> bool:
             """New score should be greater than the old one."""
             return numpy.greater(get_s(new) - self._min_delta, get_s(best))
-        def minimize(new, best):
+        def minimize(new: _Score, best: _Score) -> bool:
             """New score should be smaller than the old one."""
             return numpy.greater(get_s(best) - self._min_delta, get_s(new))

@@ -314,25 +344,25 @@ class EarlyStopping(TrainingCallback):
                 self.maximize = False
         if self.maximize:
-            self.improve_op = maximize
+            improve_op = maximize
         else:
-            self.improve_op = minimize
-        assert self.improve_op
+            improve_op = minimize
+        assert improve_op
         if not self.stopping_history:  # First round
             self.current_rounds = 0
             self.stopping_history[name] = {}
-            self.stopping_history[name][metric] = [score]
+            self.stopping_history[name][metric] = cast(_ScoreList, [score])
             self.best_scores[name] = {}
             self.best_scores[name][metric] = [score]
             model.set_attr(best_score=str(score), best_iteration=str(epoch))
-        elif not self.improve_op(score, self.best_scores[name][metric][-1]):
+        elif not improve_op(score, self.best_scores[name][metric][-1]):
             # Not improved
-            self.stopping_history[name][metric].append(score)
+            self.stopping_history[name][metric].append(score)  # type: ignore
             self.current_rounds += 1
         else:  # Improved
-            self.stopping_history[name][metric].append(score)
+            self.stopping_history[name][metric].append(score)  # type: ignore
             self.best_scores[name][metric].append(score)
             record = self.stopping_history[name][metric][-1]
             model.set_attr(best_score=str(record), best_iteration=str(epoch))

@@ -390,16 +420,16 @@ class EvaluationMonitor(TrainingCallback):
     Parameters
     ----------
-    metric : callable
+    metric :
         Extra user defined metric.
-    rank : int
+    rank :
         Which worker should be used for printing the result.
-    period : int
+    period :
         How many epoches between printing.
-    show_stdv : bool
+    show_stdv :
         Used in cv to show standard deviation. Users should not specify it.
     '''
-    def __init__(self, rank=0, period=1, show_stdv=False) -> None:
+    def __init__(self, rank: int = 0, period: int = 1, show_stdv: bool = False) -> None:
         self.printer_rank = rank
         self.show_stdv = show_stdv
         self.period = period

@@ -457,22 +487,27 @@ class TrainingCheckPoint(TrainingCallback):
     Parameters
     ----------
-    directory : os.PathLike
+    directory :
         Output model directory.
-    name : str
+    name :
         pattern of output model file. Models will be saved as name_0.json, name_1.json,
         name_2.json ....
-    as_pickle : boolean
+    as_pickle :
         When set to Ture, all training parameters will be saved in pickle format, instead
         of saving only the model.
-    iterations : int
+    iterations :
         Interval of checkpointing. Checkpointing is slow so setting a larger number can
         reduce performance hit.
     '''
-    def __init__(self, directory: os.PathLike, name: str = 'model',
-                 as_pickle=False, iterations: int = 100):
-        self._path = directory
+    def __init__(
+        self,
+        directory: Union[str, os.PathLike],
+        name: str = 'model',
+        as_pickle: bool = False,
+        iterations: int = 100
+    ) -> None:
+        self._path = os.fspath(directory)
         self._name = name
         self._as_pickle = as_pickle
         self._iterations = iterations
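
Note: the following is not part of the commit; it is a minimal sketch of how the callback API annotated above is consumed. The callback class name and the toy data are hypothetical; `TrainingCallback.EvalsLog` is the alias this diff introduces.

    import numpy as np
    import xgboost as xgb
    from xgboost.callback import TrainingCallback

    class LastScore(TrainingCallback):
        """Print the latest value of every tracked metric."""
        def after_iteration(self, model, epoch: int,
                            evals_log: TrainingCallback.EvalsLog) -> bool:
            for data_name, metrics in evals_log.items():
                for metric_name, log in metrics.items():
                    # each entry is a float, or a (mean, std) tuple under xgboost.cv
                    print(f"[{epoch}] {data_name}-{metric_name}: {log[-1]}")
            return False  # True would request early stopping

    X, y = np.random.rand(64, 4), np.random.rand(64)
    dtrain = xgb.DMatrix(X, label=y)
    xgb.train({"max_depth": 2}, dtrain, num_boost_round=3,
              evals=[(dtrain, "train")], callbacks=[LastScore()])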

File: python-package/xgboost/core.py

@@ -6,7 +6,7 @@
 from collections.abc import Mapping
 from typing import List, Optional, Any, Union, Dict, TypeVar
 # pylint: enable=no-name-in-module,import-error
-from typing import Callable, Tuple, cast
+from typing import Callable, Tuple, cast, Sequence
 import ctypes
 import os
 import re

@@ -31,20 +31,6 @@ class XGBoostError(ValueError):
     """Error thrown by xgboost trainer."""
-class EarlyStopException(Exception):
-    """Exception to signal early stopping.
-    Parameters
-    ----------
-    best_iteration : int
-        The best iteration stopped.
-    """
-    def __init__(self, best_iteration):
-        super().__init__()
-        self.best_iteration = best_iteration
 def from_pystr_to_cstr(data: Union[str, List[str]]):
     """Convert a Python str or list of Python str to C pointer

@@ -132,18 +118,19 @@ def _log_callback(msg: bytes) -> None:
         print(py_str(msg))
-def _get_log_callback_func():
+def _get_log_callback_func() -> Callable:
     """Wrap log_callback() method in ctypes callback type"""
     # pylint: disable=invalid-name
     CALLBACK = ctypes.CFUNCTYPE(None, ctypes.c_char_p)
     return CALLBACK(_log_callback)
-def _load_lib():
+def _load_lib() -> ctypes.CDLL:
     """Load xgboost Library."""
     lib_paths = find_lib_path()
     if not lib_paths:
-        return None
+        # This happens only when building document.
+        return None  # type: ignore
     try:
         pathBackup = os.environ['PATH'].split(os.pathsep)
     except KeyError:

@@ -190,7 +177,7 @@ Error message(s): {os_error_list}
 _LIB = _load_lib()
-def _check_call(ret):
+def _check_call(ret: int) -> None:
     """Check the return value of C API call
     This function will raise exception when error occurs.

@@ -234,7 +221,7 @@ def _cuda_array_interface(data) -> bytes:
     return interface_str
-def ctypes2numpy(cptr, length, dtype):
+def ctypes2numpy(cptr, length, dtype) -> np.ndarray:
     """Convert a ctypes pointer array to a numpy array."""
     ctype = _numpy2ctypes_type(dtype)
     if not isinstance(cptr, ctypes.POINTER(ctype)):

@@ -271,7 +258,7 @@ def ctypes2cupy(cptr, length, dtype):
     return arr
-def ctypes2buffer(cptr, length):
+def ctypes2buffer(cptr, length) -> bytearray:
     """Convert ctypes pointer to buffer type."""
     if not isinstance(cptr, ctypes.POINTER(ctypes.c_char)):
         raise RuntimeError('expected char pointer')

@@ -428,7 +415,7 @@ class DataIter:  # pylint: disable=too-many-instance-attributes
     Parameters
     ----------
-    data_handle:
+    input_data:
         A function with same data fields like `data`, `label` with
         `xgboost.DMatrix`.

@@ -627,7 +614,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         if feature_types is not None:
             self.feature_types = feature_types
-    def _init_from_iter(self, iterator: DataIter, enable_categorical: bool):
+    def _init_from_iter(self, iterator: DataIter, enable_categorical: bool) -> None:
         it = iterator
         args = {
             "missing": self.missing,

@@ -654,7 +641,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         _check_call(ret)
         self.handle = handle
-    def __del__(self):
+    def __del__(self) -> None:
         if hasattr(self, "handle") and self.handle:
             _check_call(_LIB.XGDMatrixFree(self.handle))
             self.handle = None

@@ -699,7 +686,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         dispatch_meta_backend(matrix=self, data=feature_weights,
                               name='feature_weights')
-    def get_float_info(self, field):
+    def get_float_info(self, field: str) -> np.ndarray:
         """Get float property from the DMatrix.
         Parameters

@@ -720,7 +707,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
                                                ctypes.byref(ret)))
         return ctypes2numpy(ret, length.value, np.float32)
-    def get_uint_info(self, field):
+    def get_uint_info(self, field: str) -> np.ndarray:
         """Get unsigned integer property from the DMatrix.
         Parameters

@@ -741,7 +728,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
                                               ctypes.byref(ret)))
         return ctypes2numpy(ret, length.value, np.uint32)
-    def set_float_info(self, field, data):
+    def set_float_info(self, field: str, data) -> None:
         """Set float type property into the DMatrix.
         Parameters

@@ -755,7 +742,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         from .data import dispatch_meta_backend
         dispatch_meta_backend(self, data, field, 'float')
-    def set_float_info_npy2d(self, field, data):
+    def set_float_info_npy2d(self, field: str, data) -> None:
         """Set float type property into the DMatrix
            for numpy 2d array input

@@ -770,7 +757,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         from .data import dispatch_meta_backend
         dispatch_meta_backend(self, data, field, 'float')
-    def set_uint_info(self, field, data):
+    def set_uint_info(self, field: str, data) -> None:
         """Set uint type property into the DMatrix.
         Parameters

@@ -784,7 +771,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         from .data import dispatch_meta_backend
         dispatch_meta_backend(self, data, field, 'uint32')
-    def save_binary(self, fname, silent=True):
+    def save_binary(self, fname, silent=True) -> None:
         """Save DMatrix to an XGBoost buffer. Saved binary can be later loaded
         by providing the path to :py:func:`xgboost.DMatrix` as input.

@@ -800,7 +787,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
                                              c_str(fname),
                                              ctypes.c_int(silent)))
-    def set_label(self, label):
+    def set_label(self, label) -> None:
         """Set label of dmatrix
         Parameters

@@ -811,7 +798,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         from .data import dispatch_meta_backend
         dispatch_meta_backend(self, label, 'label', 'float')
-    def set_weight(self, weight):
+    def set_weight(self, weight) -> None:
         """Set weight of each instance.
         Parameters

@@ -830,7 +817,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         from .data import dispatch_meta_backend
         dispatch_meta_backend(self, weight, 'weight', 'float')
-    def set_base_margin(self, margin):
+    def set_base_margin(self, margin) -> None:
         """Set base margin of booster to start from.
         This can be used to specify a prediction value of existing model to be

@@ -847,7 +834,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         from .data import dispatch_meta_backend
         dispatch_meta_backend(self, margin, 'base_margin', 'float')
-    def set_group(self, group):
+    def set_group(self, group) -> None:
         """Set group size of DMatrix (used for ranking).
         Parameters

@@ -858,7 +845,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         from .data import dispatch_meta_backend
         dispatch_meta_backend(self, group, 'group', 'uint32')
-    def get_label(self):
+    def get_label(self) -> np.ndarray:
         """Get the label of the DMatrix.
         Returns

@@ -867,7 +854,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         """
         return self.get_float_info('label')
-    def get_weight(self):
+    def get_weight(self) -> np.ndarray:
         """Get the weight of the DMatrix.
         Returns

@@ -876,7 +863,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         """
         return self.get_float_info('weight')
-    def get_base_margin(self):
+    def get_base_margin(self) -> np.ndarray:
         """Get the base margin of the DMatrix.
         Returns

@@ -885,7 +872,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         """
         return self.get_float_info('base_margin')
-    def num_row(self):
+    def num_row(self) -> int:
         """Get the number of rows in the DMatrix.
         Returns

@@ -897,7 +884,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
                                          ctypes.byref(ret)))
         return ret.value
-    def num_col(self):
+    def num_col(self) -> int:
         """Get the number of columns (features) in the DMatrix.
         Returns

@@ -1191,7 +1178,7 @@ class DeviceQuantileDMatrix(DMatrix):
             enable_categorical=enable_categorical,
         )
-    def _init(self, data, enable_categorical, **meta):
+    def _init(self, data, enable_categorical: bool, **meta) -> None:
         from .data import (
             _is_dlpack,
             _transform_dlpack,

@@ -1265,7 +1252,7 @@ def _get_booster_layer_trees(model: "Booster") -> Tuple[int, int]:
     return num_parallel_tree, num_groups
-class Booster(object):
+class Booster:
     # pylint: disable=too-many-public-methods
     """A Booster of XGBoost.

@@ -1273,7 +1260,12 @@ class Booster(object):
     training, prediction and evaluation.
     """
-    def __init__(self, params=None, cache=(), model_file=None):
+    def __init__(
+        self,
+        params: Optional[Dict] = None,
+        cache: Optional[Sequence[DMatrix]] = None,
+        model_file: Optional[Union["Booster", bytearray, os.PathLike, str]] = None
+    ) -> None:
         # pylint: disable=invalid-name
         """
         Parameters

@@ -1285,12 +1277,13 @@ class Booster(object):
         model_file : string/os.PathLike/Booster/bytearray
             Path to the model file if it's string or PathLike.
         """
+        cache = cache if cache is not None else []
         for d in cache:
             if not isinstance(d, DMatrix):
                 raise TypeError(f'invalid cache item: {type(d).__name__}', cache)
         dmats = c_array(ctypes.c_void_p, [d.handle for d in cache])
-        self.handle = ctypes.c_void_p()
+        self.handle: Optional[ctypes.c_void_p] = ctypes.c_void_p()
         _check_call(_LIB.XGBoosterCreate(dmats, c_bst_ulong(len(cache)),
                                          ctypes.byref(self.handle)))
         for d in cache:

@@ -1405,12 +1398,12 @@ class Booster(object):
         return params
-    def __del__(self):
+    def __del__(self) -> None:
         if hasattr(self, 'handle') and self.handle is not None:
             _check_call(_LIB.XGBoosterFree(self.handle))
             self.handle = None
-    def __getstate__(self):
+    def __getstate__(self) -> Dict:
         # can't pickle ctypes pointers, put model content in bytearray
         this = self.__dict__.copy()
         handle = this['handle']

@@ -1424,7 +1417,7 @@ class Booster(object):
             this["handle"] = buf
         return this
-    def __setstate__(self, state):
+    def __setstate__(self, state: Dict) -> None:
         # reconstruct handle from raw data
         handle = state['handle']
         if handle is not None:

@@ -1440,7 +1433,7 @@ class Booster(object):
             state['handle'] = handle
         self.__dict__.update(state)
-    def __getitem__(self, val):
+    def __getitem__(self, val: Union[int, tuple, slice]) -> "Booster":
         if isinstance(val, int):
             val = slice(val, val+1)
         if isinstance(val, tuple):

@@ -1461,13 +1454,14 @@ class Booster(object):
         step = val.step if val.step is not None else 1
-        start = ctypes.c_int(start)
-        stop = ctypes.c_int(stop)
-        step = ctypes.c_int(step)
+        c_start = ctypes.c_int(start)
+        c_stop = ctypes.c_int(stop)
+        c_step = ctypes.c_int(step)
         sliced_handle = ctypes.c_void_p()
-        status = _LIB.XGBoosterSlice(self.handle, start, stop, step,
-                                     ctypes.byref(sliced_handle))
+        status = _LIB.XGBoosterSlice(
+            self.handle, c_start, c_stop, c_step, ctypes.byref(sliced_handle)
+        )
         if status == -2:
             raise IndexError('Layer index out of range')
         _check_call(status)

@@ -1477,7 +1471,7 @@ class Booster(object):
         sliced.handle = sliced_handle
         return sliced
-    def save_config(self):
+    def save_config(self) -> str:
         '''Output internal parameter configuration of Booster as a JSON
         string.

@@ -1489,10 +1483,11 @@ class Booster(object):
             self.handle,
             ctypes.byref(length),
             ctypes.byref(json_string)))
-        json_string = json_string.value.decode()  # pylint: disable=no-member
-        return json_string
+        assert json_string.value is not None
+        result = json_string.value.decode()  # pylint: disable=no-member
+        return result
-    def load_config(self, config):
+    def load_config(self, config: str) -> None:
         '''Load configuration returned by `save_config`.
         .. versionadded:: 1.0.0

@@ -1502,14 +1497,14 @@ class Booster(object):
             self.handle,
             c_str(config)))
-    def __copy__(self):
+    def __copy__(self) -> "Booster":
         return self.__deepcopy__(None)
-    def __deepcopy__(self, _):
+    def __deepcopy__(self, _) -> "Booster":
         '''Return a copy of booster.'''
         return Booster(model_file=self)
-    def copy(self):
+    def copy(self) -> "Booster":
         """Copy the booster object.
         Returns

@@ -1519,7 +1514,7 @@ class Booster(object):
         """
         return self.__copy__()
-    def attr(self, key):
+    def attr(self, key: str) -> Optional[str]:
         """Get attribute string from the Booster.
         Parameters

@@ -1540,7 +1535,7 @@ class Booster(object):
             return py_str(ret.value)
         return None
-    def attributes(self):
+    def attributes(self) -> Dict[str, str]:
         """Get attributes stored in the Booster as a dictionary.
         Returns

@@ -1572,7 +1567,7 @@ class Booster(object):
         _check_call(_LIB.XGBoosterSetAttr(
             self.handle, c_str(key), value))
-    def _get_feature_info(self, field: str):
+    def _get_feature_info(self, field: str) -> Optional[List[str]]:
         length = c_bst_ulong()
         sarr = ctypes.POINTER(ctypes.c_char_p)()
         if not hasattr(self, "handle") or self.handle is None:

@@ -1585,22 +1580,6 @@ class Booster(object):
         feature_info = from_cstr_to_pystr(sarr, length)
         return feature_info if feature_info else None
-    @property
-    def feature_types(self) -> Optional[List[str]]:
-        """Feature types for this booster. Can be directly set by input data or by
-        assignment.
-        """
-        return self._get_feature_info("feature_type")
-    @property
-    def feature_names(self) -> Optional[List[str]]:
-        """Feature names for this booster. Can be directly set by input data or by
-        assignment.
-        """
-        return self._get_feature_info("feature_name")
     def _set_feature_info(self, features: Optional[List[str]], field: str) -> None:
         if features is not None:
             assert isinstance(features, list)

@@ -1618,14 +1597,30 @@ class Booster(object):
             )
         )
-    @feature_names.setter
-    def feature_names(self, features: Optional[List[str]]) -> None:
-        self._set_feature_info(features, "feature_name")
+    @property
+    def feature_types(self) -> Optional[List[str]]:
+        """Feature types for this booster. Can be directly set by input data or by
+        assignment.
+        """
+        return self._get_feature_info("feature_type")
     @feature_types.setter
     def feature_types(self, features: Optional[List[str]]) -> None:
         self._set_feature_info(features, "feature_type")
+    @property
+    def feature_names(self) -> Optional[List[str]]:
+        """Feature names for this booster. Can be directly set by input data or by
+        assignment.
+        """
+        return self._get_feature_info("feature_name")
+    @feature_names.setter
+    def feature_names(self, features: Optional[List[str]]) -> None:
+        self._set_feature_info(features, "feature_name")
     def set_param(self, params, value=None):
         """Set parameters into the Booster.

@@ -1645,7 +1640,9 @@ class Booster(object):
                 _check_call(_LIB.XGBoosterSetParam(self.handle, c_str(key),
                                                    c_str(str(val))))
-    def update(self, dtrain, iteration, fobj=None):
+    def update(
+        self, dtrain: DMatrix, iteration: int, fobj: Optional[Objective] = None
+    ) -> None:
         """Update for one iteration, with objective function calculated
         internally. This function should not be called directly by users.

@@ -1672,18 +1669,18 @@ class Booster(object):
             grad, hess = fobj(pred, dtrain)
             self.boost(dtrain, grad, hess)
-    def boost(self, dtrain, grad, hess):
+    def boost(self, dtrain: DMatrix, grad: np.ndarray, hess: np.ndarray) -> None:
         """Boost the booster for one iteration, with customized gradient
         statistics. Like :py:func:`xgboost.Booster.update`, this
         function should not be called directly by users.
         Parameters
         ----------
-        dtrain : DMatrix
+        dtrain :
             The training DMatrix.
-        grad : list
+        grad :
             The first order of gradient.
-        hess : list
+        hess :
             The second order of gradient.
         """

@@ -1700,17 +1697,23 @@ class Booster(object):
                                                c_array(ctypes.c_float, hess),
                                                c_bst_ulong(len(grad))))
-    def eval_set(self, evals, iteration=0, feval=None, output_margin=True):
+    def eval_set(
+        self,
+        evals: Sequence[Tuple[DMatrix, str]],
+        iteration: int = 0,
+        feval: Optional[Metric] = None,
+        output_margin: bool = True
+    ) -> str:
         # pylint: disable=invalid-name
         """Evaluate a set of data.
         Parameters
         ----------
-        evals : list of tuples (DMatrix, string)
+        evals :
             List of items to be evaluated.
-        iteration : int
+        iteration :
             Current iteration.
-        feval : function
+        feval :
             Custom evaluation function.
         Returns

@@ -1738,6 +1741,7 @@ class Booster(object):
                     ctypes.byref(msg),
                 )
             )
+            assert msg.value is not None
             res = msg.value.decode()  # pylint: disable=no-member
         if feval is not None:
             for dmat, evname in evals:

@@ -1754,18 +1758,18 @@ class Booster(object):
                         res += "\t%s-%s:%f" % (evname, name, val)
         return res
-    def eval(self, data, name='eval', iteration=0):
+    def eval(self, data: DMatrix, name: str = 'eval', iteration: int = 0) -> str:
         """Evaluate the model on mat.
         Parameters
         ----------
-        data : DMatrix
+        data :
             The dmatrix storing the input.
-        name : str, optional
+        name :
             The name of the dataset.
-        iteration : int, optional
+        iteration :
             The current iteration number.
         Returns

@@ -2101,7 +2105,7 @@ class Booster(object):
             "Data type:" + str(type(data)) + " not supported by inplace prediction."
         )
-    def save_model(self, fname: Union[str, os.PathLike]):
+    def save_model(self, fname: Union[str, os.PathLike]) -> None:
         """Save the model to a file.
         The model is saved in an XGBoost internal format which is universal among the

@@ -2124,7 +2128,7 @@ class Booster(object):
         else:
             raise TypeError("fname must be a string or os PathLike")
-    def save_raw(self):
+    def save_raw(self) -> bytearray:
         """Save the model to a in memory buffer representation instead of file.
         Returns

@@ -2232,18 +2236,23 @@ class Booster(object):
         if need_close:
             fout.close()
-    def get_dump(self, fmap='', with_stats=False, dump_format="text"):
+    def get_dump(
+        self,
+        fmap: Union[str, os.PathLike] = "",
+        with_stats: bool = False,
+        dump_format: str = "text"
+    ) -> List[str]:
         """Returns the model dump as a list of strings. Unlike `save_model`, the
         output format is primarily used for visualization or interpretation,
         hence it's more human readable but cannot be loaded back to XGBoost.
         Parameters
         ----------
-        fmap : string or os.PathLike, optional
+        fmap :
             Name of the file containing feature map names.
-        with_stats : bool, optional
+        with_stats :
             Controls whether the split statistics are output.
-        dump_format : string, optional
+        dump_format :
             Format of model dump. Can be 'text', 'json' or 'dot'.
         """

@@ -2259,7 +2268,9 @@ class Booster(object):
         res = from_cstr_to_pystr(sarr, length)
         return res
-    def get_fscore(self, fmap=''):
+    def get_fscore(
+        self, fmap: Union[str, os.PathLike] = ""
+    ) -> Dict[str, Union[float, List[float]]]:
         """Get feature importance of each feature.
         .. note:: Zero-importance features will not be included

@@ -2269,7 +2280,7 @@ class Booster(object):
         Parameters
         ----------
-        fmap: str or os.PathLike (optional)
+        fmap :
            The name of feature map file
         """

@@ -2299,9 +2310,9 @@ class Booster(object):
         Parameters
         ----------
-        fmap: str or os.PathLike (optional)
+        fmap:
            The name of feature map file.
-        importance_type: str, default 'weight'
+        importance_type:
            One of the importance types defined above.
         Returns

@@ -2343,7 +2354,8 @@ class Booster(object):
                 results[feat] = float(score)
         return results
-    def trees_to_dataframe(self, fmap=''):  # pylint: disable=too-many-statements
+    # pylint: disable=too-many-statements
+    def trees_to_dataframe(self, fmap: Union[str, os.PathLike] = '') -> DataFrame:
         """Parse a boosted tree model text dump into a pandas DataFrame structure.
         This feature is only defined when the decision tree model is chosen as base

@@ -2370,7 +2382,7 @@ class Booster(object):
         node_ids = []
         fids = []
         splits = []
-        categories = []
+        categories: List[Optional[float]] = []
         y_directs = []
         n_directs = []
         missings = []

@@ -2444,7 +2456,7 @@ class Booster(object):
             # pylint: disable=no-member
             return df.sort(['Tree', 'Node']).reset_index(drop=True)
-    def _validate_features(self, data: DMatrix):
+    def _validate_features(self, data: DMatrix) -> None:
         """
         Validate Booster and data's feature_names are identical.
         Set feature_names and feature_types from DMatrix
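
Note: not part of the commit; a short sketch exercising the Booster/DMatrix methods whose signatures are annotated above (the variable names are hypothetical). `num_row`/`num_col` now advertise `int`, `__getitem__` returns a `"Booster"`, and `feature_names` is an `Optional[List[str]]` property with a separate setter.

    import numpy as np
    import xgboost as xgb

    X, y = np.random.rand(100, 4), np.random.rand(100)
    dtrain = xgb.DMatrix(X, label=y)
    print(dtrain.num_row(), dtrain.num_col())   # 100 4

    booster = xgb.train({"max_depth": 2}, dtrain, num_boost_round=8)
    first_half = booster[:4]                    # slicing returns a new Booster
    print(first_half.num_boosted_rounds())      # 4
    print(booster.feature_names)                # None for plain numpy input
    booster.feature_names = [f"f{i}" for i in range(4)]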

File: python-package/xgboost/dask.py

@ -17,12 +17,13 @@ https://github.com/dask/dask-xgboost
""" """
import platform import platform
import logging import logging
import collections
from contextlib import contextmanager from contextlib import contextmanager
from collections import defaultdict from collections import defaultdict
from collections.abc import Sequence
from threading import Thread from threading import Thread
from functools import partial, update_wrapper from functools import partial, update_wrapper
from typing import TYPE_CHECKING, List, Tuple, Callable, Optional, Any, Union, Dict, Set from typing import TYPE_CHECKING, List, Tuple, Callable, Optional, Any, Union, Dict, Set
from typing import Sequence
from typing import Awaitable, Generator, TypeVar from typing import Awaitable, Generator, TypeVar
import numpy import numpy
@ -524,9 +525,9 @@ class DaskPartitionIter(DataIter): # pylint: disable=R0902
self._feature_names = feature_names self._feature_names = feature_names
self._feature_types = feature_types self._feature_types = feature_types
assert isinstance(self._data, Sequence) assert isinstance(self._data, collections.abc.Sequence)
types = (Sequence, type(None)) types = (collections.abc.Sequence, type(None))
assert isinstance(self._labels, types) assert isinstance(self._labels, types)
assert isinstance(self._weights, types) assert isinstance(self._weights, types)
assert isinstance(self._base_margin, types) assert isinstance(self._base_margin, types)
@ -817,7 +818,7 @@ async def _get_rabit_args(n_workers: int, client: "distributed.Client") -> List[
def _get_workers_from_data( def _get_workers_from_data(
dtrain: DaskDMatrix, dtrain: DaskDMatrix,
evals: Optional[List[Tuple[DaskDMatrix, str]]] evals: Optional[Sequence[Tuple[DaskDMatrix, str]]]
) -> List[str]: ) -> List[str]:
X_worker_map: Set[str] = set(dtrain.worker_map.keys()) X_worker_map: Set[str] = set(dtrain.worker_map.keys())
if evals: if evals:
@ -837,13 +838,13 @@ async def _train_async(
params: Dict[str, Any], params: Dict[str, Any],
dtrain: DaskDMatrix, dtrain: DaskDMatrix,
num_boost_round: int, num_boost_round: int,
evals: Optional[List[Tuple[DaskDMatrix, str]]], evals: Optional[Sequence[Tuple[DaskDMatrix, str]]],
obj: Optional[Objective], obj: Optional[Objective],
feval: Optional[Metric], feval: Optional[Metric],
early_stopping_rounds: Optional[int], early_stopping_rounds: Optional[int],
verbose_eval: Union[int, bool], verbose_eval: Union[int, bool],
xgb_model: Optional[Booster], xgb_model: Optional[Booster],
callbacks: Optional[List[TrainingCallback]], callbacks: Optional[Sequence[TrainingCallback]],
custom_metric: Optional[Metric], custom_metric: Optional[Metric],
) -> Optional[TrainReturnT]: ) -> Optional[TrainReturnT]:
workers = _get_workers_from_data(dtrain, evals) workers = _get_workers_from_data(dtrain, evals)
@ -951,13 +952,13 @@ def train( # pylint: disable=unused-argument
dtrain: DaskDMatrix, dtrain: DaskDMatrix,
num_boost_round: int = 10, num_boost_round: int = 10,
*, *,
evals: Optional[List[Tuple[DaskDMatrix, str]]] = None, evals: Optional[Sequence[Tuple[DaskDMatrix, str]]] = None,
obj: Optional[Objective] = None, obj: Optional[Objective] = None,
feval: Optional[Metric] = None, feval: Optional[Metric] = None,
early_stopping_rounds: Optional[int] = None, early_stopping_rounds: Optional[int] = None,
xgb_model: Optional[Booster] = None, xgb_model: Optional[Booster] = None,
verbose_eval: Union[int, bool] = True, verbose_eval: Union[int, bool] = True,
callbacks: Optional[List[TrainingCallback]] = None, callbacks: Optional[Sequence[TrainingCallback]] = None,
custom_metric: Optional[Metric] = None, custom_metric: Optional[Metric] = None,
) -> Any: ) -> Any:
"""Train XGBoost model. """Train XGBoost model.
@ -1648,15 +1649,15 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
y: _DaskCollection, y: _DaskCollection,
sample_weight: Optional[_DaskCollection], sample_weight: Optional[_DaskCollection],
base_margin: Optional[_DaskCollection], base_margin: Optional[_DaskCollection],
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]], eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]],
eval_metric: Optional[Union[str, List[str], Metric]], eval_metric: Optional[Union[str, Sequence[str], Metric]],
sample_weight_eval_set: Optional[List[_DaskCollection]], sample_weight_eval_set: Optional[Sequence[_DaskCollection]],
base_margin_eval_set: Optional[List[_DaskCollection]], base_margin_eval_set: Optional[Sequence[_DaskCollection]],
early_stopping_rounds: Optional[int], early_stopping_rounds: Optional[int],
verbose: bool, verbose: bool,
xgb_model: Optional[Union[Booster, XGBModel]], xgb_model: Optional[Union[Booster, XGBModel]],
feature_weights: Optional[_DaskCollection], feature_weights: Optional[_DaskCollection],
callbacks: Optional[List[TrainingCallback]], callbacks: Optional[Sequence[TrainingCallback]],
) -> _DaskCollection: ) -> _DaskCollection:
params = self.get_xgb_params() params = self.get_xgb_params()
dtrain, evals = await _async_wrap_evaluation_matrices( dtrain, evals = await _async_wrap_evaluation_matrices(
@ -1714,15 +1715,15 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
*, *,
sample_weight: Optional[_DaskCollection] = None, sample_weight: Optional[_DaskCollection] = None,
base_margin: Optional[_DaskCollection] = None, base_margin: Optional[_DaskCollection] = None,
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None, eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
eval_metric: Optional[Union[str, List[str], Metric]] = None, eval_metric: Optional[Union[str, Sequence[str], Callable]] = None,
early_stopping_rounds: Optional[int] = None, early_stopping_rounds: Optional[int] = None,
verbose: bool = True, verbose: bool = True,
xgb_model: Optional[Union[Booster, XGBModel]] = None, xgb_model: Optional[Union[Booster, XGBModel]] = None,
sample_weight_eval_set: Optional[List[_DaskCollection]] = None, sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
base_margin_eval_set: Optional[List[_DaskCollection]] = None, base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
feature_weights: Optional[_DaskCollection] = None, feature_weights: Optional[_DaskCollection] = None,
callbacks: Optional[List[TrainingCallback]] = None, callbacks: Optional[Sequence[TrainingCallback]] = None,
) -> "DaskXGBRegressor": ) -> "DaskXGBRegressor":
_assert_dask_support() _assert_dask_support()
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")} args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
@ -1738,15 +1739,15 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
self, X: _DaskCollection, y: _DaskCollection, self, X: _DaskCollection, y: _DaskCollection,
sample_weight: Optional[_DaskCollection], sample_weight: Optional[_DaskCollection],
base_margin: Optional[_DaskCollection], base_margin: Optional[_DaskCollection],
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]], eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]],
eval_metric: Optional[Union[str, List[str], Metric]], eval_metric: Optional[Union[str, Sequence[str], Metric]],
sample_weight_eval_set: Optional[List[_DaskCollection]], sample_weight_eval_set: Optional[Sequence[_DaskCollection]],
base_margin_eval_set: Optional[List[_DaskCollection]], base_margin_eval_set: Optional[Sequence[_DaskCollection]],
early_stopping_rounds: Optional[int], early_stopping_rounds: Optional[int],
verbose: bool, verbose: bool,
xgb_model: Optional[Union[Booster, XGBModel]], xgb_model: Optional[Union[Booster, XGBModel]],
feature_weights: Optional[_DaskCollection], feature_weights: Optional[_DaskCollection],
callbacks: Optional[List[TrainingCallback]] callbacks: Optional[Sequence[TrainingCallback]]
) -> "DaskXGBClassifier": ) -> "DaskXGBClassifier":
params = self.get_xgb_params() params = self.get_xgb_params()
dtrain, evals = await _async_wrap_evaluation_matrices( dtrain, evals = await _async_wrap_evaluation_matrices(
@ -1818,15 +1819,15 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
*, *,
sample_weight: Optional[_DaskCollection] = None, sample_weight: Optional[_DaskCollection] = None,
base_margin: Optional[_DaskCollection] = None, base_margin: Optional[_DaskCollection] = None,
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None, eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
eval_metric: Optional[Union[str, List[str], Metric]] = None, eval_metric: Optional[Union[str, Sequence[str], Callable]] = None,
early_stopping_rounds: Optional[int] = None, early_stopping_rounds: Optional[int] = None,
verbose: bool = True, verbose: bool = True,
xgb_model: Optional[Union[Booster, XGBModel]] = None, xgb_model: Optional[Union[Booster, XGBModel]] = None,
sample_weight_eval_set: Optional[List[_DaskCollection]] = None, sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
base_margin_eval_set: Optional[List[_DaskCollection]] = None, base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
feature_weights: Optional[_DaskCollection] = None, feature_weights: Optional[_DaskCollection] = None,
callbacks: Optional[List[TrainingCallback]] = None callbacks: Optional[Sequence[TrainingCallback]] = None
) -> "DaskXGBClassifier": ) -> "DaskXGBClassifier":
_assert_dask_support() _assert_dask_support()
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")} args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
@ -1935,17 +1936,17 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
qid: Optional[_DaskCollection], qid: Optional[_DaskCollection],
sample_weight: Optional[_DaskCollection], sample_weight: Optional[_DaskCollection],
base_margin: Optional[_DaskCollection], base_margin: Optional[_DaskCollection],
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]], eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]],
sample_weight_eval_set: Optional[List[_DaskCollection]], sample_weight_eval_set: Optional[Sequence[_DaskCollection]],
base_margin_eval_set: Optional[List[_DaskCollection]], base_margin_eval_set: Optional[Sequence[_DaskCollection]],
eval_group: Optional[List[_DaskCollection]], eval_group: Optional[Sequence[_DaskCollection]],
eval_qid: Optional[List[_DaskCollection]], eval_qid: Optional[Sequence[_DaskCollection]],
eval_metric: Optional[Union[str, List[str], Metric]], eval_metric: Optional[Union[str, Sequence[str], Metric]],
early_stopping_rounds: Optional[int], early_stopping_rounds: Optional[int],
verbose: bool, verbose: bool,
xgb_model: Optional[Union[XGBModel, Booster]], xgb_model: Optional[Union[XGBModel, Booster]],
feature_weights: Optional[_DaskCollection], feature_weights: Optional[_DaskCollection],
callbacks: Optional[List[TrainingCallback]], callbacks: Optional[Sequence[TrainingCallback]],
) -> "DaskXGBRanker": ) -> "DaskXGBRanker":
msg = "Use `qid` instead of `group` on dask interface." msg = "Use `qid` instead of `group` on dask interface."
if not (group is None and eval_group is None): if not (group is None and eval_group is None):
@ -2010,17 +2011,17 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
qid: Optional[_DaskCollection] = None, qid: Optional[_DaskCollection] = None,
sample_weight: Optional[_DaskCollection] = None, sample_weight: Optional[_DaskCollection] = None,
base_margin: Optional[_DaskCollection] = None, base_margin: Optional[_DaskCollection] = None,
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None, eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
eval_group: Optional[List[_DaskCollection]] = None, eval_group: Optional[Sequence[_DaskCollection]] = None,
eval_qid: Optional[List[_DaskCollection]] = None, eval_qid: Optional[Sequence[_DaskCollection]] = None,
eval_metric: Optional[Union[str, List[str], Metric]] = None, eval_metric: Optional[Union[str, Sequence[str], Callable]] = None,
early_stopping_rounds: int = None, early_stopping_rounds: int = None,
verbose: bool = False, verbose: bool = False,
xgb_model: Optional[Union[XGBModel, Booster]] = None, xgb_model: Optional[Union[XGBModel, Booster]] = None,
sample_weight_eval_set: Optional[List[_DaskCollection]] = None, sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
base_margin_eval_set: Optional[List[_DaskCollection]] = None, base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
feature_weights: Optional[_DaskCollection] = None, feature_weights: Optional[_DaskCollection] = None,
callbacks: Optional[List[TrainingCallback]] = None callbacks: Optional[Sequence[TrainingCallback]] = None
) -> "DaskXGBRanker": ) -> "DaskXGBRanker":
_assert_dask_support() _assert_dask_support()
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")} args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
@ -2077,15 +2078,15 @@ class DaskXGBRFRegressor(DaskXGBRegressor):
*, *,
sample_weight: Optional[_DaskCollection] = None, sample_weight: Optional[_DaskCollection] = None,
base_margin: Optional[_DaskCollection] = None, base_margin: Optional[_DaskCollection] = None,
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None, eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
eval_metric: Optional[Union[str, List[str], Metric]] = None, eval_metric: Optional[Union[str, Sequence[str], Callable]] = None,
early_stopping_rounds: Optional[int] = None, early_stopping_rounds: Optional[int] = None,
verbose: bool = True, verbose: bool = True,
xgb_model: Optional[Union[Booster, XGBModel]] = None, xgb_model: Optional[Union[Booster, XGBModel]] = None,
sample_weight_eval_set: Optional[List[_DaskCollection]] = None, sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
base_margin_eval_set: Optional[List[_DaskCollection]] = None, base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
feature_weights: Optional[_DaskCollection] = None, feature_weights: Optional[_DaskCollection] = None,
callbacks: Optional[List[TrainingCallback]] = None callbacks: Optional[Sequence[TrainingCallback]] = None
) -> "DaskXGBRFRegressor": ) -> "DaskXGBRFRegressor":
_assert_dask_support() _assert_dask_support()
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")} args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
@ -2141,15 +2142,15 @@ class DaskXGBRFClassifier(DaskXGBClassifier):
*, *,
sample_weight: Optional[_DaskCollection] = None, sample_weight: Optional[_DaskCollection] = None,
base_margin: Optional[_DaskCollection] = None, base_margin: Optional[_DaskCollection] = None,
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None, eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
eval_metric: Optional[Union[str, List[str], Metric]] = None, eval_metric: Optional[Union[str, Sequence[str], Metric]] = None,
early_stopping_rounds: Optional[int] = None, early_stopping_rounds: Optional[int] = None,
verbose: bool = True, verbose: bool = True,
xgb_model: Optional[Union[Booster, XGBModel]] = None, xgb_model: Optional[Union[Booster, XGBModel]] = None,
sample_weight_eval_set: Optional[List[_DaskCollection]] = None, sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
base_margin_eval_set: Optional[List[_DaskCollection]] = None, base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
feature_weights: Optional[_DaskCollection] = None, feature_weights: Optional[_DaskCollection] = None,
callbacks: Optional[List[TrainingCallback]] = None callbacks: Optional[Sequence[TrainingCallback]] = None
) -> "DaskXGBRFClassifier": ) -> "DaskXGBRFClassifier":
_assert_dask_support() _assert_dask_support()
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")} args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
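The recurring ``List`` to ``Sequence`` change in these signatures is the usual variance fix: ``List[T]`` promises mutability and is invariant, while a read-only parameter annotated as ``Sequence[T]`` accepts lists, tuples, and other ordered containers. A sketch of the difference as seen by a static checker such as mypy:

.. code-block:: python

    from typing import List, Sequence

    def wants_list(metrics: List[str]) -> None: ...
    def wants_sequence(metrics: Sequence[str]) -> None: ...

    wants_list(("auc", "logloss"))      # flagged by mypy: a tuple is not a List
    wants_sequence(("auc", "logloss"))  # accepted: a tuple is a Sequence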
View File
@ -4,7 +4,8 @@ import copy
import warnings import warnings
import json import json
import os import os
from typing import Union, Optional, List, Dict, Callable, Tuple, Any, TypeVar, Type from typing import Union, Optional, List, Dict, Callable, Tuple, Any, TypeVar, Type, cast
from typing import Sequence
import numpy as np import numpy as np
from .core import Booster, DMatrix, XGBoostError from .core import Booster, DMatrix, XGBoostError
@ -36,7 +37,7 @@ class XGBRankerMixIn: # pylint: disable=too-few-public-methods
def _check_rf_callback( def _check_rf_callback(
early_stopping_rounds: Optional[int], early_stopping_rounds: Optional[int],
callbacks: Optional[List[TrainingCallback]], callbacks: Optional[Sequence[TrainingCallback]],
) -> None: ) -> None:
if early_stopping_rounds is not None or callbacks is not None: if early_stopping_rounds is not None or callbacks is not None:
raise NotImplementedError( raise NotImplementedError(
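``_check_rf_callback`` guards the random forest wrappers, which train a fixed number of rounds and therefore do not support early stopping or training callbacks. A sketch of the observable behaviour with synthetic data (``X`` and ``y`` are illustrative):

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    X = np.random.rand(20, 3)
    y = np.random.randint(0, 2, size=20)

    clf = xgb.XGBRFClassifier(n_estimators=4)
    try:
        clf.fit(X, y, early_stopping_rounds=5)
    except NotImplementedError:
        pass  # the guard rejects early stopping and callbacks for forests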
@ -343,14 +344,14 @@ def _wrap_evaluation_matrices(
sample_weight: Optional[Any], sample_weight: Optional[Any],
base_margin: Optional[Any], base_margin: Optional[Any],
feature_weights: Optional[Any], feature_weights: Optional[Any],
eval_set: Optional[List[Tuple[Any, Any]]], eval_set: Optional[Sequence[Tuple[Any, Any]]],
sample_weight_eval_set: Optional[List[Any]], sample_weight_eval_set: Optional[Sequence[Any]],
base_margin_eval_set: Optional[List[Any]], base_margin_eval_set: Optional[Sequence[Any]],
eval_group: Optional[List[Any]], eval_group: Optional[Sequence[Any]],
eval_qid: Optional[List[Any]], eval_qid: Optional[Sequence[Any]],
create_dmatrix: Callable, create_dmatrix: Callable,
enable_categorical: bool, enable_categorical: bool,
) -> Tuple[Any, Optional[List[Tuple[Any, str]]]]: ) -> Tuple[Any, List[Tuple[Any, str]]]:
"""Convert array_like evaluation matrices into DMatrix. Perform validation on the way. """Convert array_like evaluation matrices into DMatrix. Perform validation on the way.
""" """
@ -368,7 +369,7 @@ def _wrap_evaluation_matrices(
n_validation = 0 if eval_set is None else len(eval_set) n_validation = 0 if eval_set is None else len(eval_set)
def validate_or_none(meta: Optional[List], name: str) -> List: def validate_or_none(meta: Optional[Sequence], name: str) -> Sequence:
if meta is None: if meta is None:
return [None] * n_validation return [None] * n_validation
if len(meta) != n_validation: if len(meta) != n_validation:
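``validate_or_none`` normalises the optional per-evaluation-set metadata: ``None`` expands to one placeholder per entry of ``eval_set``, and an explicit sequence must match its length. A standalone sketch of that contract, adapted to take ``n_validation`` as a parameter instead of closing over it (the message text is illustrative):

.. code-block:: python

    from typing import Optional, Sequence

    def validate_or_none(meta: Optional[Sequence], name: str, n_validation: int) -> Sequence:
        if meta is None:
            return [None] * n_validation  # one placeholder per evaluation set
        if len(meta) != n_validation:
            raise ValueError(f"{name} must have the same length as eval_set")
        return meta

    assert validate_or_none(None, "sample_weight_eval_set", 2) == [None, None]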
@ -464,7 +465,7 @@ class XGBModel(XGBModelBase):
missing: float = np.nan, missing: float = np.nan,
num_parallel_tree: Optional[int] = None, num_parallel_tree: Optional[int] = None,
monotone_constraints: Optional[Union[Dict[str, int], str]] = None, monotone_constraints: Optional[Union[Dict[str, int], str]] = None,
interaction_constraints: Optional[Union[str, List[Tuple[str]]]] = None, interaction_constraints: Optional[Union[str, Sequence[Sequence[str]]]] = None,
importance_type: Optional[str] = None, importance_type: Optional[str] = None,
gpu_id: Optional[int] = None, gpu_id: Optional[int] = None,
validate_parameters: Optional[bool] = None, validate_parameters: Optional[bool] = None,
@ -715,7 +716,7 @@ class XGBModel(XGBModelBase):
def _configure_fit( def _configure_fit(
self, self,
booster: Optional[Union[Booster, "XGBModel", str]], booster: Optional[Union[Booster, "XGBModel", str]],
eval_metric: Optional[Union[Callable, str, List[str]]], eval_metric: Optional[Union[Callable, str, Sequence[str]]],
params: Dict[str, Any], params: Dict[str, Any],
early_stopping_rounds: Optional[int], early_stopping_rounds: Optional[int],
) -> Tuple[ ) -> Tuple[
@ -788,10 +789,7 @@ class XGBModel(XGBModelBase):
def _set_evaluation_result(self, evals_result: TrainingCallback.EvalsLog) -> None: def _set_evaluation_result(self, evals_result: TrainingCallback.EvalsLog) -> None:
if evals_result: if evals_result:
for val in evals_result.items(): self.evals_result_ = cast(Dict[str, Dict[str, List[float]]], evals_result)
evals_result_key = list(val[1].keys())[0]
evals_result[val[0]][evals_result_key] = val[1][evals_result_key]
self.evals_result_ = evals_result
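The new body leans on ``typing.cast``, which narrows the static type without any runtime effect; the key-by-key loop it replaces assigned each entry back onto itself, so no behaviour is lost. A minimal sketch:

.. code-block:: python

    from typing import Dict, List, cast

    log = {"validation_0": {"logloss": [0.6, 0.5]}}
    narrowed = cast(Dict[str, Dict[str, List[float]]], log)
    assert narrowed is log  # cast is a no-op at runtime; only the checker sees it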
@_deprecate_positional_args @_deprecate_positional_args
def fit( def fit(
@ -801,15 +799,15 @@ class XGBModel(XGBModelBase):
*, *,
sample_weight: Optional[array_like] = None, sample_weight: Optional[array_like] = None,
base_margin: Optional[array_like] = None, base_margin: Optional[array_like] = None,
eval_set: Optional[List[Tuple[array_like, array_like]]] = None, eval_set: Optional[Sequence[Tuple[array_like, array_like]]] = None,
eval_metric: Optional[Union[str, List[str], Metric]] = None, eval_metric: Optional[Union[str, Sequence[str], Metric]] = None,
early_stopping_rounds: Optional[int] = None, early_stopping_rounds: Optional[int] = None,
verbose: Optional[bool] = True, verbose: Optional[bool] = True,
xgb_model: Optional[Union[Booster, str, "XGBModel"]] = None, xgb_model: Optional[Union[Booster, str, "XGBModel"]] = None,
sample_weight_eval_set: Optional[List[array_like]] = None, sample_weight_eval_set: Optional[Sequence[array_like]] = None,
base_margin_eval_set: Optional[List[array_like]] = None, base_margin_eval_set: Optional[Sequence[array_like]] = None,
feature_weights: Optional[array_like] = None, feature_weights: Optional[array_like] = None,
callbacks: Optional[List[TrainingCallback]] = None callbacks: Optional[Sequence[TrainingCallback]] = None
) -> "XGBModel": ) -> "XGBModel":
# pylint: disable=invalid-name,attribute-defined-outside-init # pylint: disable=invalid-name,attribute-defined-outside-init
"""Fit gradient boosting model. """Fit gradient boosting model.
@ -1031,7 +1029,7 @@ class XGBModel(XGBModelBase):
Input features matrix. Input features matrix.
iteration_range : iteration_range :
See :py:meth:`xgboost.XGBRegressor.predict`. See :py:meth:`predict`.
ntree_limit : ntree_limit :
Deprecated, use ``iteration_range`` instead. Deprecated, use ``iteration_range`` instead.
@ -1055,40 +1053,26 @@ class XGBModel(XGBModelBase):
iteration_range=iteration_range iteration_range=iteration_range
) )
def evals_result(self) -> TrainingCallback.EvalsLog: def evals_result(self) -> Dict[str, Dict[str, List[float]]]:
"""Return the evaluation results. """Return the evaluation results.
If **eval_set** is passed to the `fit` function, you can call If **eval_set** is passed to the :py:meth:`fit` function, you can call
``evals_result()`` to get evaluation results for all passed **eval_sets**. ``evals_result()`` to get evaluation results for all passed **eval_sets**. When
When **eval_metric** is also passed to the `fit` function, the **eval_metric** is also passed to the :py:meth:`fit` function, the
**evals_result** will contain the **eval_metrics** passed to the `fit` function. **evals_result** will contain the **eval_metrics** passed to the :py:meth:`fit`
function.
Returns The returned evaluation result is a dictionary:
-------
evals_result : dictionary
Example
-------
.. code-block:: python
param_dist = {'objective':'binary:logistic', 'n_estimators':2}
clf = xgb.XGBModel(**param_dist)
clf.fit(X_train, y_train,
eval_set=[(X_train, y_train), (X_test, y_test)],
eval_metric='logloss',
verbose=True)
evals_result = clf.evals_result()
The variable **evals_result** will contain:
.. code-block:: python .. code-block:: python
{'validation_0': {'logloss': ['0.604835', '0.531479']}, {'validation_0': {'logloss': ['0.604835', '0.531479']},
'validation_1': {'logloss': ['0.41965', '0.17686']}} 'validation_1': {'logloss': ['0.41965', '0.17686']}}
Returns
-------
evals_result
""" """
if getattr(self, "evals_result_", None) is not None: if getattr(self, "evals_result_", None) is not None:
evals_result = self.evals_result_ evals_result = self.evals_result_
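The condensed docstring drops the worked example; a sketch of the same flow, assuming pre-split ``X_train``, ``y_train``, ``X_test``, ``y_test``:

.. code-block:: python

    import xgboost as xgb

    clf = xgb.XGBClassifier(objective="binary:logistic", n_estimators=2)
    clf.fit(
        X_train,
        y_train,
        eval_set=[(X_train, y_train), (X_test, y_test)],
        eval_metric="logloss",
        verbose=False,
    )
    results = clf.evals_result()
    # e.g. {'validation_0': {'logloss': [...]}, 'validation_1': {'logloss': [...]}}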
@ -1193,8 +1177,8 @@ class XGBModel(XGBModelBase):
.. note:: Intercept is defined only for linear learners .. note:: Intercept is defined only for linear learners
Intercept (bias) is only defined when the linear model is chosen as base Intercept (bias) is only defined when the linear model is chosen as base
learner (`booster=gblinear`). It is not defined for other base learner types, such learner (`booster=gblinear`). It is not defined for other base learner types,
as tree learners (`booster=gbtree`). such as tree learners (`booster=gbtree`).
Returns Returns
------- -------
@ -1251,15 +1235,15 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
*, *,
sample_weight: Optional[array_like] = None, sample_weight: Optional[array_like] = None,
base_margin: Optional[array_like] = None, base_margin: Optional[array_like] = None,
eval_set: Optional[List[Tuple[array_like, array_like]]] = None, eval_set: Optional[Sequence[Tuple[array_like, array_like]]] = None,
eval_metric: Optional[Union[str, List[str], Metric]] = None, eval_metric: Optional[Union[str, Sequence[str], Metric]] = None,
early_stopping_rounds: Optional[int] = None, early_stopping_rounds: Optional[int] = None,
verbose: Optional[bool] = True, verbose: Optional[bool] = True,
xgb_model: Optional[Union[Booster, str, XGBModel]] = None, xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
sample_weight_eval_set: Optional[List[array_like]] = None, sample_weight_eval_set: Optional[Sequence[array_like]] = None,
base_margin_eval_set: Optional[List[array_like]] = None, base_margin_eval_set: Optional[Sequence[array_like]] = None,
feature_weights: Optional[array_like] = None, feature_weights: Optional[array_like] = None,
callbacks: Optional[List[TrainingCallback]] = None callbacks: Optional[Sequence[TrainingCallback]] = None
) -> "XGBClassifier": ) -> "XGBClassifier":
# pylint: disable = attribute-defined-outside-init,too-many-statements # pylint: disable = attribute-defined-outside-init,too-many-statements
evals_result: TrainingCallback.EvalsLog = {} evals_result: TrainingCallback.EvalsLog = {}
@ -1445,51 +1429,6 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
getattr(self, "n_classes_", None), class_probs, np.vstack getattr(self, "n_classes_", None), class_probs, np.vstack
) )
def evals_result(self) -> TrainingCallback.EvalsLog:
"""Return the evaluation results.
If **eval_set** is passed to the `fit` function, you can call
``evals_result()`` to get evaluation results for all passed **eval_sets**.
When **eval_metric** is also passed as a parameter, the **evals_result** will
contain the **eval_metric** passed to the `fit` function.
Returns
-------
evals_result : dictionary
Example
-------
.. code-block:: python
param_dist = {
'objective':'binary:logistic', 'n_estimators':2, eval_metric="logloss"
}
clf = xgb.XGBClassifier(**param_dist)
clf.fit(X_train, y_train,
eval_set=[(X_train, y_train), (X_test, y_test)],
verbose=True)
evals_result = clf.evals_result()
The variable **evals_result** will contain
.. code-block:: python
{'validation_0': {'logloss': ['0.604835', '0.531479']},
'validation_1': {'logloss': ['0.41965', '0.17686']}}
"""
if self.evals_result_:
evals_result = self.evals_result_
else:
raise XGBoostError('No results.')
return evals_result
@xgboost_model_doc( @xgboost_model_doc(
"scikit-learn API for XGBoost random forest classification.", "scikit-learn API for XGBoost random forest classification.",
@ -1533,15 +1472,15 @@ class XGBRFClassifier(XGBClassifier):
*, *,
sample_weight: Optional[array_like] = None, sample_weight: Optional[array_like] = None,
base_margin: Optional[array_like] = None, base_margin: Optional[array_like] = None,
eval_set: Optional[List[Tuple[array_like, array_like]]] = None, eval_set: Optional[Sequence[Tuple[array_like, array_like]]] = None,
eval_metric: Optional[Union[str, List[str], Metric]] = None, eval_metric: Optional[Union[str, Sequence[str], Metric]] = None,
early_stopping_rounds: Optional[int] = None, early_stopping_rounds: Optional[int] = None,
verbose: Optional[bool] = True, verbose: Optional[bool] = True,
xgb_model: Optional[Union[Booster, str, XGBModel]] = None, xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
sample_weight_eval_set: Optional[List[array_like]] = None, sample_weight_eval_set: Optional[Sequence[array_like]] = None,
base_margin_eval_set: Optional[List[array_like]] = None, base_margin_eval_set: Optional[Sequence[array_like]] = None,
feature_weights: Optional[array_like] = None, feature_weights: Optional[array_like] = None,
callbacks: Optional[List[TrainingCallback]] = None callbacks: Optional[Sequence[TrainingCallback]] = None
) -> "XGBRFClassifier": ) -> "XGBRFClassifier":
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")} args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
_check_rf_callback(early_stopping_rounds, callbacks) _check_rf_callback(early_stopping_rounds, callbacks)
@ -1605,15 +1544,15 @@ class XGBRFRegressor(XGBRegressor):
*, *,
sample_weight: Optional[array_like] = None, sample_weight: Optional[array_like] = None,
base_margin: Optional[array_like] = None, base_margin: Optional[array_like] = None,
eval_set: Optional[List[Tuple[array_like, array_like]]] = None, eval_set: Optional[Sequence[Tuple[array_like, array_like]]] = None,
eval_metric: Optional[Union[str, List[str], Metric]] = None, eval_metric: Optional[Union[str, Sequence[str], Metric]] = None,
early_stopping_rounds: Optional[int] = None, early_stopping_rounds: Optional[int] = None,
verbose: Optional[bool] = True, verbose: Optional[bool] = True,
xgb_model: Optional[Union[Booster, str, XGBModel]] = None, xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
sample_weight_eval_set: Optional[List[array_like]] = None, sample_weight_eval_set: Optional[Sequence[array_like]] = None,
base_margin_eval_set: Optional[List[array_like]] = None, base_margin_eval_set: Optional[Sequence[array_like]] = None,
feature_weights: Optional[array_like] = None, feature_weights: Optional[array_like] = None,
callbacks: Optional[List[TrainingCallback]] = None callbacks: Optional[Sequence[TrainingCallback]] = None
) -> "XGBRFRegressor": ) -> "XGBRFRegressor":
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")} args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
_check_rf_callback(early_stopping_rounds, callbacks) _check_rf_callback(early_stopping_rounds, callbacks)
@ -1682,17 +1621,17 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
qid: Optional[array_like] = None, qid: Optional[array_like] = None,
sample_weight: Optional[array_like] = None, sample_weight: Optional[array_like] = None,
base_margin: Optional[array_like] = None, base_margin: Optional[array_like] = None,
eval_set: Optional[List[Tuple[array_like, array_like]]] = None, eval_set: Optional[Sequence[Tuple[array_like, array_like]]] = None,
eval_group: Optional[List[array_like]] = None, eval_group: Optional[Sequence[array_like]] = None,
eval_qid: Optional[List[array_like]] = None, eval_qid: Optional[Sequence[array_like]] = None,
eval_metric: Optional[Union[str, List[str], Metric]] = None, eval_metric: Optional[Union[str, Sequence[str], Metric]] = None,
early_stopping_rounds: Optional[int] = None, early_stopping_rounds: Optional[int] = None,
verbose: Optional[bool] = False, verbose: Optional[bool] = False,
xgb_model: Optional[Union[Booster, str, XGBModel]] = None, xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
sample_weight_eval_set: Optional[List[array_like]] = None, sample_weight_eval_set: Optional[Sequence[array_like]] = None,
base_margin_eval_set: Optional[List[array_like]] = None, base_margin_eval_set: Optional[Sequence[array_like]] = None,
feature_weights: Optional[array_like] = None, feature_weights: Optional[array_like] = None,
callbacks: Optional[List[TrainingCallback]] = None callbacks: Optional[Sequence[TrainingCallback]] = None
) -> "XGBRanker": ) -> "XGBRanker":
# pylint: disable = attribute-defined-outside-init,arguments-differ # pylint: disable = attribute-defined-outside-init,arguments-differ
"""Fit gradient boosting ranker """Fit gradient boosting ranker
View File
@ -3,18 +3,20 @@
# pylint: disable=too-many-branches, too-many-statements # pylint: disable=too-many-branches, too-many-statements
"""Training Library containing training routines.""" """Training Library containing training routines."""
import copy import copy
from typing import Optional, List import os
import warnings import warnings
from typing import Optional, Dict, Any, Union, Tuple, cast, Sequence
import numpy as np import numpy as np
from .core import Booster, XGBoostError, _get_booster_layer_trees from .core import Booster, DMatrix, XGBoostError, _get_booster_layer_trees
from .core import _deprecate_positional_args from .core import Metric, Objective
from .core import Objective, Metric
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold) from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
from . import callback from . import callback
def _assert_new_callback(callbacks: Optional[List[callback.TrainingCallback]]) -> None: def _assert_new_callback(
callbacks: Optional[Sequence[callback.TrainingCallback]]
) -> None:
is_new_callback: bool = not callbacks or all( is_new_callback: bool = not callbacks or all(
isinstance(c, callback.TrainingCallback) for c in callbacks isinstance(c, callback.TrainingCallback) for c in callbacks
) )
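Only instances of ``TrainingCallback`` pass the ``isinstance`` check above, so bare functions from the legacy callback API are rejected. A sketch of a conforming callback (``PrintRound`` is an illustrative name; ``params`` and ``dtrain`` are assumed to exist):

.. code-block:: python

    import xgboost as xgb

    class PrintRound(xgb.callback.TrainingCallback):
        def after_iteration(self, model, epoch, evals_log) -> bool:
            print("finished round", epoch)
            return False  # returning True would stop training early

    # xgb.train(params, dtrain, callbacks=[PrintRound()])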
@ -44,24 +46,24 @@ def _configure_custom_metric(
def _train_internal( def _train_internal(
params, params: Dict[str, Any],
dtrain, dtrain: DMatrix,
num_boost_round=10, num_boost_round: int = 10,
evals=(), evals: Optional[Sequence[Tuple[DMatrix, str]]] = None,
obj=None, obj: Optional[Objective] = None,
feval=None, feval: Optional[Metric] = None,
custom_metric=None, custom_metric: Optional[Metric] = None,
xgb_model=None, xgb_model: Optional[Union[str, os.PathLike, Booster, bytearray]] = None,
callbacks=None, callbacks: Optional[Sequence[callback.TrainingCallback]] = None,
evals_result=None, evals_result: Optional[callback.TrainingCallback.EvalsLog] = None,
maximize=None, maximize: Optional[bool] = None,
verbose_eval=None, verbose_eval: Optional[Union[bool, int]] = True,
early_stopping_rounds=None, early_stopping_rounds: Optional[int] = None,
): ) -> Booster:
"""internal training function""" """internal training function"""
callbacks = [] if callbacks is None else copy.copy(callbacks) callbacks = [] if callbacks is None else copy.copy(list(callbacks))
metric_fn = _configure_custom_metric(feval, custom_metric) metric_fn = _configure_custom_metric(feval, custom_metric)
evals = list(evals) evals = list(evals) if evals else []
bst = Booster(params, [dtrain] + [d[0] for d in evals]) bst = Booster(params, [dtrain] + [d[0] for d in evals])
@ -78,7 +80,7 @@ def _train_internal(
callbacks.append( callbacks.append(
callback.EarlyStopping(rounds=early_stopping_rounds, maximize=maximize) callback.EarlyStopping(rounds=early_stopping_rounds, maximize=maximize)
) )
callbacks = callback.CallbackContainer( cb_container = callback.CallbackContainer(
callbacks, callbacks,
metric=metric_fn, metric=metric_fn,
# For old `feval` parameter, the behavior is unchanged. For the new # For old `feval` parameter, the behavior is unchanged. For the new
@ -87,32 +89,32 @@ def _train_internal(
output_margin=callable(obj) or metric_fn is feval, output_margin=callable(obj) or metric_fn is feval,
) )
bst = callbacks.before_training(bst) bst = cb_container.before_training(bst)
for i in range(start_iteration, num_boost_round): for i in range(start_iteration, num_boost_round):
if callbacks.before_iteration(bst, i, dtrain, evals): if cb_container.before_iteration(bst, i, dtrain, evals):
break break
bst.update(dtrain, i, obj) bst.update(dtrain, i, obj)
if callbacks.after_iteration(bst, i, dtrain, evals): if cb_container.after_iteration(bst, i, dtrain, evals):
break break
bst = callbacks.after_training(bst) bst = cb_container.after_training(bst)
if evals_result is not None: if evals_result is not None:
evals_result.update(callbacks.history) evals_result.update(cb_container.history)
# These should be moved into callback functions `after_training`, but until old # These should be moved into callback functions `after_training`, but until old
# callbacks are removed, the train function is the only place for setting the # callbacks are removed, the train function is the only place for setting the
# attributes. # attributes.
num_parallel_tree, _ = _get_booster_layer_trees(bst) num_parallel_tree, _ = _get_booster_layer_trees(bst)
if bst.attr('best_score') is not None: if bst.attr('best_score') is not None:
bst.best_score = float(bst.attr('best_score')) bst.best_score = float(cast(str, bst.attr('best_score')))
bst.best_iteration = int(bst.attr('best_iteration')) bst.best_iteration = int(cast(str, bst.attr('best_iteration')))
# num_class is handled internally # num_class is handled internally
bst.set_attr( bst.set_attr(
best_ntree_limit=str((bst.best_iteration + 1) * num_parallel_tree) best_ntree_limit=str((bst.best_iteration + 1) * num_parallel_tree)
) )
bst.best_ntree_limit = int(bst.attr("best_ntree_limit")) bst.best_ntree_limit = int(cast(str, bst.attr("best_ntree_limit")))
else: else:
# Due to compatibility with version older than 1.4, these attributes are added # Due to compatibility with version older than 1.4, these attributes are added
# to Python object even if early stopping is not used. # to Python object even if early stopping is not used.
@ -126,35 +128,32 @@ def _train_internal(
return bst.copy() return bst.copy()
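``best_ntree_limit`` above is derived rather than read back from the C++ core: each boosting layer holds ``num_parallel_tree`` trees, so the limit is ``(best_iteration + 1) * num_parallel_tree``. A worked sketch:

.. code-block:: python

    # early stopping at iteration 7 with a 4-tree forest grown per round
    best_iteration, num_parallel_tree = 7, 4
    best_ntree_limit = (best_iteration + 1) * num_parallel_tree
    assert best_ntree_limit == 32  # 8 completed layers times 4 trees each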
@_deprecate_positional_args
def train( def train(
params, params: Dict[str, Any],
dtrain, dtrain: DMatrix,
num_boost_round=10, num_boost_round: int = 10,
*, evals: Optional[Sequence[Tuple[DMatrix, str]]] = None,
evals=(),
obj: Optional[Objective] = None, obj: Optional[Objective] = None,
feval=None, feval: Optional[Metric] = None,
maximize=None, maximize: Optional[bool] = None,
early_stopping_rounds=None, early_stopping_rounds: Optional[int] = None,
evals_result=None, evals_result: Optional[callback.TrainingCallback.EvalsLog] = None,
verbose_eval=True, verbose_eval: Optional[Union[bool, int]] = True,
xgb_model=None, xgb_model: Optional[Union[str, os.PathLike, Booster, bytearray]] = None,
callbacks=None, callbacks: Optional[Sequence[callback.TrainingCallback]] = None,
custom_metric: Optional[Metric] = None, custom_metric: Optional[Metric] = None,
): ) -> Booster:
# pylint: disable=too-many-statements,too-many-branches, attribute-defined-outside-init
"""Train a booster with given parameters. """Train a booster with given parameters.
Parameters Parameters
---------- ----------
params : dict params :
Booster params. Booster params.
dtrain : DMatrix dtrain :
Data to be trained. Data to be trained.
num_boost_round: int num_boost_round :
Number of boosting iterations. Number of boosting iterations.
evals: list of pairs (DMatrix, string) evals :
List of validation sets for which metrics will be evaluated during training. List of validation sets for which metrics will be evaluated during training.
Validation metrics will help us track the performance of the model. Validation metrics will help us track the performance of the model.
obj obj
@ -166,7 +165,7 @@ def train(
Use `custom_metric` instead. Use `custom_metric` instead.
maximize : bool maximize : bool
Whether to maximize feval. Whether to maximize feval.
early_stopping_rounds: int early_stopping_rounds :
Activates early stopping. Validation metric needs to improve at least once in Activates early stopping. Validation metric needs to improve at least once in
every **early_stopping_rounds** round(s) to continue training. every **early_stopping_rounds** round(s) to continue training.
Requires at least one item in **evals**. Requires at least one item in **evals**.
@ -178,7 +177,7 @@ def train(
**params**, the last metric will be used for early stopping. **params**, the last metric will be used for early stopping.
If early stopping occurs, the model will have two additional fields: If early stopping occurs, the model will have two additional fields:
``bst.best_score``, ``bst.best_iteration``. ``bst.best_score``, ``bst.best_iteration``.
evals_result: dict evals_result :
This dictionary stores the evaluation results of all the items in watchlist. This dictionary stores the evaluation results of all the items in watchlist.
Example: with a watchlist containing Example: with a watchlist containing
@ -191,7 +190,7 @@ def train(
{'train': {'logloss': ['0.48253', '0.35953']}, {'train': {'logloss': ['0.48253', '0.35953']},
'eval': {'logloss': ['0.480385', '0.357756']}} 'eval': {'logloss': ['0.480385', '0.357756']}}
verbose_eval : bool or int verbose_eval :
Requires at least one item in **evals**. Requires at least one item in **evals**.
If **verbose_eval** is True then the evaluation metric on the validation set is If **verbose_eval** is True then the evaluation metric on the validation set is
printed at each boosting stage. printed at each boosting stage.
@ -200,9 +199,9 @@ def train(
/ the boosting stage found by using **early_stopping_rounds** is also printed. / the boosting stage found by using **early_stopping_rounds** is also printed.
Example: with ``verbose_eval=4`` and at least one item in **evals**, an evaluation metric Example: with ``verbose_eval=4`` and at least one item in **evals**, an evaluation metric
is printed every 4 boosting stages, instead of every boosting stage. is printed every 4 boosting stages, instead of every boosting stage.
xgb_model : file name of stored xgb model or 'Booster' instance xgb_model :
Xgb model to be loaded before training (allows training continuation). Xgb model to be loaded before training (allows training continuation).
callbacks : list of callback functions callbacks :
List of callback functions that are applied at end of each iteration. List of callback functions that are applied at end of each iteration.
It is possible to use predefined callbacks by using It is possible to use predefined callbacks by using
:ref:`Callback API <callback_api>`. :ref:`Callback API <callback_api>`.
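An end-to-end sketch of the retyped ``train`` signature, collecting metrics into ``evals_result`` as the docstring describes; the data here is synthetic:

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    rng = np.random.RandomState(1994)
    dtrain = xgb.DMatrix(rng.rand(100, 4), label=rng.randint(0, 2, size=100))
    deval = xgb.DMatrix(rng.rand(50, 4), label=rng.randint(0, 2, size=50))

    evals_result: xgb.callback.TrainingCallback.EvalsLog = {}
    booster = xgb.train(
        {"objective": "binary:logistic", "eval_metric": "logloss"},
        dtrain,
        num_boost_round=4,
        evals=[(dtrain, "train"), (deval, "eval")],
        early_stopping_rounds=2,
        evals_result=evals_result,
        verbose_eval=False,
    )
    # evals_result now holds {'train': {'logloss': [...]}, 'eval': {'logloss': [...]}}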
View File
@ -9,7 +9,9 @@ rng = np.random.RandomState(1994)
class TestInteractionConstraints: class TestInteractionConstraints:
def run_interaction_constraints(self, tree_method, feature_names=None, interaction_constraints='[[0, 1]]'): def run_interaction_constraints(
self, tree_method, feature_names=None, interaction_constraints='[[0, 1]]'
):
x1 = np.random.normal(loc=1.0, scale=1.0, size=1000) x1 = np.random.normal(loc=1.0, scale=1.0, size=1000)
x2 = np.random.normal(loc=1.0, scale=1.0, size=1000) x2 = np.random.normal(loc=1.0, scale=1.0, size=1000)
x3 = np.random.choice([1, 2, 3], size=1000, replace=True) x3 = np.random.choice([1, 2, 3], size=1000, replace=True)