Typehint for subset of core API. (#7348)
This commit is contained in:
parent
45aef75cca
commit
c6769488b3
@ -6,17 +6,18 @@ from abc import ABC
|
||||
import collections
|
||||
import os
|
||||
import pickle
|
||||
from typing import Callable, List, Optional, Union, Dict, Tuple
|
||||
from typing import Callable, List, Optional, Union, Dict, Tuple, TypeVar, cast
|
||||
from typing import Sequence
|
||||
import numpy
|
||||
|
||||
from . import rabit
|
||||
from .core import Booster, XGBoostError
|
||||
from .core import Booster, DMatrix, XGBoostError
|
||||
from .compat import STRING_TYPES
|
||||
|
||||
|
||||
# The new implementation of callback functions.
|
||||
# Breaking:
|
||||
# - reset learning rate no longer accepts total boosting rounds
|
||||
_Score = Union[float, Tuple[float, float]]
|
||||
_ScoreList = Union[List[float], List[Tuple[float, float]]]
|
||||
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
class TrainingCallback(ABC):
|
||||
@ -26,9 +27,9 @@ class TrainingCallback(ABC):
|
||||
|
||||
'''
|
||||
|
||||
EvalsLog = Dict[str, Dict[str, Union[List[float], List[Tuple[float, float]]]]]
|
||||
EvalsLog = Dict[str, Dict[str, _ScoreList]]
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def before_training(self, model):
|
||||
@ -48,18 +49,18 @@ class TrainingCallback(ABC):
|
||||
return False
|
||||
|
||||
|
||||
def _aggcv(rlist):
|
||||
# pylint: disable=invalid-name
|
||||
def _aggcv(rlist: List[str]) -> List[Tuple[str, float, float]]:
|
||||
# pylint: disable=invalid-name, too-many-locals
|
||||
"""Aggregate cross-validation results.
|
||||
|
||||
"""
|
||||
cvmap = {}
|
||||
cvmap: Dict[Tuple[int, str], List[float]] = {}
|
||||
idx = rlist[0].split()[0]
|
||||
for line in rlist:
|
||||
arr = line.split()
|
||||
arr: List[str] = line.split()
|
||||
assert idx == arr[0]
|
||||
for metric_idx, it in enumerate(arr[1:]):
|
||||
if not isinstance(it, STRING_TYPES):
|
||||
if not isinstance(it, str):
|
||||
it = it.decode()
|
||||
k, v = it.split(':')
|
||||
if (metric_idx, k) not in cvmap:
|
||||
@ -67,16 +68,20 @@ def _aggcv(rlist):
|
||||
cvmap[(metric_idx, k)].append(float(v))
|
||||
msg = idx
|
||||
results = []
|
||||
for (metric_idx, k), v in sorted(cvmap.items(), key=lambda x: x[0][0]):
|
||||
v = numpy.array(v)
|
||||
for (_, name), s in sorted(cvmap.items(), key=lambda x: x[0][0]):
|
||||
as_arr = numpy.array(s)
|
||||
if not isinstance(msg, STRING_TYPES):
|
||||
msg = msg.decode()
|
||||
mean, std = numpy.mean(v), numpy.std(v)
|
||||
results.extend([(k, mean, std)])
|
||||
mean, std = numpy.mean(as_arr), numpy.std(as_arr)
|
||||
results.extend([(name, mean, std)])
|
||||
return results
|
||||
|
||||
|
||||
def _allreduce_metric(score):
|
||||
# allreduce type
|
||||
_ART = TypeVar("_ART")
|
||||
|
||||
|
||||
def _allreduce_metric(score: _ART) -> _ART:
|
||||
'''Helper function for computing customized metric in distributed
|
||||
environment. Not strictly correct as many functions don't use mean value
|
||||
as final result.
|
||||
@ -89,13 +94,13 @@ def _allreduce_metric(score):
|
||||
if isinstance(score, tuple): # has mean and stdv
|
||||
raise ValueError(
|
||||
'xgboost.cv function should not be used in distributed environment.')
|
||||
score = numpy.array([score])
|
||||
score = rabit.allreduce(score, rabit.Op.SUM) / world
|
||||
return score[0]
|
||||
arr = numpy.array([score])
|
||||
arr = rabit.allreduce(arr, rabit.Op.SUM) / world
|
||||
return arr[0]
|
||||
|
||||
|
||||
class CallbackContainer:
|
||||
'''A special callback for invoking a list of other callbacks.
|
||||
'''A special internal callback for invoking a list of other callbacks.
|
||||
|
||||
.. versionadded:: 1.3.0
|
||||
|
||||
@ -105,7 +110,7 @@ class CallbackContainer:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
callbacks: List[TrainingCallback],
|
||||
callbacks: Sequence[TrainingCallback],
|
||||
metric: Callable = None,
|
||||
output_margin: bool = True,
|
||||
is_cv: bool = False
|
||||
@ -146,33 +151,50 @@ class CallbackContainer:
|
||||
assert isinstance(model, Booster), msg
|
||||
return model
|
||||
|
||||
def before_iteration(self, model, epoch, dtrain, evals) -> bool:
|
||||
def before_iteration(
|
||||
self, model, epoch: int, dtrain: DMatrix, evals: List[Tuple[DMatrix, str]]
|
||||
) -> bool:
|
||||
'''Function called before training iteration.'''
|
||||
return any(c.before_iteration(model, epoch, self.history)
|
||||
for c in self.callbacks)
|
||||
|
||||
def _update_history(self, score, epoch):
|
||||
def _update_history(
|
||||
self,
|
||||
score: Union[List[Tuple[str, float]], List[Tuple[str, float, float]]],
|
||||
epoch: int
|
||||
) -> None:
|
||||
for d in score:
|
||||
name, s = d[0], float(d[1])
|
||||
name: str = d[0]
|
||||
s: float = d[1]
|
||||
if self.is_cv:
|
||||
std = float(d[2])
|
||||
s = (s, std)
|
||||
std = float(cast(Tuple[str, float, float], d)[2])
|
||||
x: _Score = (s, std)
|
||||
else:
|
||||
x = s
|
||||
splited_names = name.split('-')
|
||||
data_name = splited_names[0]
|
||||
metric_name = '-'.join(splited_names[1:])
|
||||
s = _allreduce_metric(s)
|
||||
if data_name in self.history:
|
||||
data_history = self.history[data_name]
|
||||
if metric_name in data_history:
|
||||
data_history[metric_name].append(s)
|
||||
else:
|
||||
data_history[metric_name] = [s]
|
||||
else:
|
||||
x = _allreduce_metric(x)
|
||||
if data_name not in self.history:
|
||||
self.history[data_name] = collections.OrderedDict()
|
||||
self.history[data_name][metric_name] = [s]
|
||||
return False
|
||||
data_history = self.history[data_name]
|
||||
if metric_name not in data_history:
|
||||
data_history[metric_name] = cast(_ScoreList, [])
|
||||
metric_history = data_history[metric_name]
|
||||
if self.is_cv:
|
||||
cast(List[Tuple[float, float]], metric_history).append(
|
||||
cast(Tuple[float, float], x)
|
||||
)
|
||||
else:
|
||||
cast(List[float], metric_history).append(cast(float, x))
|
||||
|
||||
def after_iteration(self, model, epoch, dtrain, evals) -> bool:
|
||||
def after_iteration(
|
||||
self,
|
||||
model,
|
||||
epoch: int,
|
||||
dtrain: DMatrix,
|
||||
evals: Optional[List[Tuple[DMatrix, str]]],
|
||||
) -> bool:
|
||||
'''Function called after training iteration.'''
|
||||
if self.is_cv:
|
||||
scores = model.eval(epoch, self.metric, self._output_margin)
|
||||
@ -183,18 +205,20 @@ class CallbackContainer:
|
||||
evals = [] if evals is None else evals
|
||||
for _, name in evals:
|
||||
assert name.find('-') == -1, 'Dataset name should not contain `-`'
|
||||
score = model.eval_set(evals, epoch, self.metric, self._output_margin)
|
||||
score = score.split()[1:] # into datasets
|
||||
score: str = model.eval_set(evals, epoch, self.metric, self._output_margin)
|
||||
splited = score.split()[1:] # into datasets
|
||||
# split up `test-error:0.1234`
|
||||
score = [tuple(s.split(':')) for s in score]
|
||||
self._update_history(score, epoch)
|
||||
metric_score_str = [tuple(s.split(':')) for s in splited]
|
||||
# convert to float
|
||||
metric_score = [(n, float(s)) for n, s in metric_score_str]
|
||||
self._update_history(metric_score, epoch)
|
||||
ret = any(c.after_iteration(model, epoch, self.history)
|
||||
for c in self.callbacks)
|
||||
return ret
|
||||
|
||||
|
||||
class LearningRateScheduler(TrainingCallback):
|
||||
'''Callback function for scheduling learning rate.
|
||||
"""Callback function for scheduling learning rate.
|
||||
|
||||
.. versionadded:: 1.3.0
|
||||
|
||||
@ -207,18 +231,24 @@ class LearningRateScheduler(TrainingCallback):
|
||||
should be a sequence like list or tuple with the same size of boosting
|
||||
rounds.
|
||||
|
||||
'''
|
||||
def __init__(self, learning_rates) -> None:
|
||||
assert callable(learning_rates) or \
|
||||
isinstance(learning_rates, collections.abc.Sequence)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, learning_rates: Union[Callable[[int], float], Sequence[float]]
|
||||
) -> None:
|
||||
assert callable(learning_rates) or isinstance(
|
||||
learning_rates, collections.abc.Sequence
|
||||
)
|
||||
if callable(learning_rates):
|
||||
self.learning_rates = learning_rates
|
||||
else:
|
||||
self.learning_rates = lambda epoch: learning_rates[epoch]
|
||||
self.learning_rates = lambda epoch: cast(Sequence, learning_rates)[epoch]
|
||||
super().__init__()
|
||||
|
||||
def after_iteration(self, model, epoch, evals_log) -> bool:
|
||||
model.set_param('learning_rate', self.learning_rates(epoch))
|
||||
def after_iteration(
|
||||
self, model, epoch: int, evals_log: TrainingCallback.EvalsLog
|
||||
) -> bool:
|
||||
model.set_param("learning_rate", self.learning_rates(epoch))
|
||||
return False
|
||||
|
||||
|
||||
@ -230,17 +260,17 @@ class EarlyStopping(TrainingCallback):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
rounds
|
||||
rounds :
|
||||
Early stopping rounds.
|
||||
metric_name
|
||||
metric_name :
|
||||
Name of metric that is used for early stopping.
|
||||
data_name
|
||||
data_name :
|
||||
Name of dataset that is used for early stopping.
|
||||
maximize
|
||||
maximize :
|
||||
Whether to maximize evaluation metric. None means auto (discouraged).
|
||||
save_best
|
||||
save_best :
|
||||
Whether training should return the best model or the last model.
|
||||
min_delta
|
||||
min_delta :
|
||||
Minimum absolute change in score to be qualified as an improvement.
|
||||
|
||||
.. versionadded:: 1.5.0
|
||||
@ -279,8 +309,6 @@ class EarlyStopping(TrainingCallback):
|
||||
if self._min_delta < 0:
|
||||
raise ValueError("min_delta must be greater or equal to 0.")
|
||||
|
||||
self.improve_op = None
|
||||
|
||||
self.current_rounds: int = 0
|
||||
self.best_scores: dict = {}
|
||||
self.starting_round: int = 0
|
||||
@ -290,16 +318,18 @@ class EarlyStopping(TrainingCallback):
|
||||
self.starting_round = model.num_boosted_rounds()
|
||||
return model
|
||||
|
||||
def _update_rounds(self, score, name, metric, model, epoch) -> bool:
|
||||
def get_s(x):
|
||||
def _update_rounds(
|
||||
self, score: _Score, name: str, metric: str, model, epoch: int
|
||||
) -> bool:
|
||||
def get_s(x: _Score) -> float:
|
||||
"""get score if it's cross validation history."""
|
||||
return x[0] if isinstance(x, tuple) else x
|
||||
|
||||
def maximize(new, best):
|
||||
def maximize(new: _Score, best: _Score) -> bool:
|
||||
"""New score should be greater than the old one."""
|
||||
return numpy.greater(get_s(new) - self._min_delta, get_s(best))
|
||||
|
||||
def minimize(new, best):
|
||||
def minimize(new: _Score, best: _Score) -> bool:
|
||||
"""New score should be smaller than the old one."""
|
||||
return numpy.greater(get_s(best) - self._min_delta, get_s(new))
|
||||
|
||||
@ -314,25 +344,25 @@ class EarlyStopping(TrainingCallback):
|
||||
self.maximize = False
|
||||
|
||||
if self.maximize:
|
||||
self.improve_op = maximize
|
||||
improve_op = maximize
|
||||
else:
|
||||
self.improve_op = minimize
|
||||
improve_op = minimize
|
||||
|
||||
assert self.improve_op
|
||||
assert improve_op
|
||||
|
||||
if not self.stopping_history: # First round
|
||||
self.current_rounds = 0
|
||||
self.stopping_history[name] = {}
|
||||
self.stopping_history[name][metric] = [score]
|
||||
self.stopping_history[name][metric] = cast(_ScoreList, [score])
|
||||
self.best_scores[name] = {}
|
||||
self.best_scores[name][metric] = [score]
|
||||
model.set_attr(best_score=str(score), best_iteration=str(epoch))
|
||||
elif not self.improve_op(score, self.best_scores[name][metric][-1]):
|
||||
elif not improve_op(score, self.best_scores[name][metric][-1]):
|
||||
# Not improved
|
||||
self.stopping_history[name][metric].append(score)
|
||||
self.stopping_history[name][metric].append(score) # type: ignore
|
||||
self.current_rounds += 1
|
||||
else: # Improved
|
||||
self.stopping_history[name][metric].append(score)
|
||||
self.stopping_history[name][metric].append(score) # type: ignore
|
||||
self.best_scores[name][metric].append(score)
|
||||
record = self.stopping_history[name][metric][-1]
|
||||
model.set_attr(best_score=str(record), best_iteration=str(epoch))
|
||||
@ -390,16 +420,16 @@ class EvaluationMonitor(TrainingCallback):
|
||||
Parameters
|
||||
----------
|
||||
|
||||
metric : callable
|
||||
metric :
|
||||
Extra user defined metric.
|
||||
rank : int
|
||||
rank :
|
||||
Which worker should be used for printing the result.
|
||||
period : int
|
||||
period :
|
||||
How many epoches between printing.
|
||||
show_stdv : bool
|
||||
show_stdv :
|
||||
Used in cv to show standard deviation. Users should not specify it.
|
||||
'''
|
||||
def __init__(self, rank=0, period=1, show_stdv=False) -> None:
|
||||
def __init__(self, rank: int = 0, period: int = 1, show_stdv: bool = False) -> None:
|
||||
self.printer_rank = rank
|
||||
self.show_stdv = show_stdv
|
||||
self.period = period
|
||||
@ -457,22 +487,27 @@ class TrainingCheckPoint(TrainingCallback):
|
||||
Parameters
|
||||
----------
|
||||
|
||||
directory : os.PathLike
|
||||
directory :
|
||||
Output model directory.
|
||||
name : str
|
||||
name :
|
||||
pattern of output model file. Models will be saved as name_0.json, name_1.json,
|
||||
name_2.json ....
|
||||
as_pickle : boolean
|
||||
as_pickle :
|
||||
When set to Ture, all training parameters will be saved in pickle format, instead
|
||||
of saving only the model.
|
||||
iterations : int
|
||||
iterations :
|
||||
Interval of checkpointing. Checkpointing is slow so setting a larger number can
|
||||
reduce performance hit.
|
||||
|
||||
'''
|
||||
def __init__(self, directory: os.PathLike, name: str = 'model',
|
||||
as_pickle=False, iterations: int = 100):
|
||||
self._path = directory
|
||||
def __init__(
|
||||
self,
|
||||
directory: Union[str, os.PathLike],
|
||||
name: str = 'model',
|
||||
as_pickle: bool = False,
|
||||
iterations: int = 100
|
||||
) -> None:
|
||||
self._path = os.fspath(directory)
|
||||
self._name = name
|
||||
self._as_pickle = as_pickle
|
||||
self._iterations = iterations
|
||||
|
||||
@ -6,7 +6,7 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import List, Optional, Any, Union, Dict, TypeVar
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
from typing import Callable, Tuple, cast
|
||||
from typing import Callable, Tuple, cast, Sequence
|
||||
import ctypes
|
||||
import os
|
||||
import re
|
||||
@ -31,20 +31,6 @@ class XGBoostError(ValueError):
|
||||
"""Error thrown by xgboost trainer."""
|
||||
|
||||
|
||||
class EarlyStopException(Exception):
|
||||
"""Exception to signal early stopping.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
best_iteration : int
|
||||
The best iteration stopped.
|
||||
"""
|
||||
|
||||
def __init__(self, best_iteration):
|
||||
super().__init__()
|
||||
self.best_iteration = best_iteration
|
||||
|
||||
|
||||
def from_pystr_to_cstr(data: Union[str, List[str]]):
|
||||
"""Convert a Python str or list of Python str to C pointer
|
||||
|
||||
@ -132,18 +118,19 @@ def _log_callback(msg: bytes) -> None:
|
||||
print(py_str(msg))
|
||||
|
||||
|
||||
def _get_log_callback_func():
|
||||
def _get_log_callback_func() -> Callable:
|
||||
"""Wrap log_callback() method in ctypes callback type"""
|
||||
# pylint: disable=invalid-name
|
||||
CALLBACK = ctypes.CFUNCTYPE(None, ctypes.c_char_p)
|
||||
return CALLBACK(_log_callback)
|
||||
|
||||
|
||||
def _load_lib():
|
||||
def _load_lib() -> ctypes.CDLL:
|
||||
"""Load xgboost Library."""
|
||||
lib_paths = find_lib_path()
|
||||
if not lib_paths:
|
||||
return None
|
||||
# This happens only when building document.
|
||||
return None # type: ignore
|
||||
try:
|
||||
pathBackup = os.environ['PATH'].split(os.pathsep)
|
||||
except KeyError:
|
||||
@ -190,7 +177,7 @@ Error message(s): {os_error_list}
|
||||
_LIB = _load_lib()
|
||||
|
||||
|
||||
def _check_call(ret):
|
||||
def _check_call(ret: int) -> None:
|
||||
"""Check the return value of C API call
|
||||
|
||||
This function will raise exception when error occurs.
|
||||
@ -234,7 +221,7 @@ def _cuda_array_interface(data) -> bytes:
|
||||
return interface_str
|
||||
|
||||
|
||||
def ctypes2numpy(cptr, length, dtype):
|
||||
def ctypes2numpy(cptr, length, dtype) -> np.ndarray:
|
||||
"""Convert a ctypes pointer array to a numpy array."""
|
||||
ctype = _numpy2ctypes_type(dtype)
|
||||
if not isinstance(cptr, ctypes.POINTER(ctype)):
|
||||
@ -271,7 +258,7 @@ def ctypes2cupy(cptr, length, dtype):
|
||||
return arr
|
||||
|
||||
|
||||
def ctypes2buffer(cptr, length):
|
||||
def ctypes2buffer(cptr, length) -> bytearray:
|
||||
"""Convert ctypes pointer to buffer type."""
|
||||
if not isinstance(cptr, ctypes.POINTER(ctypes.c_char)):
|
||||
raise RuntimeError('expected char pointer')
|
||||
@ -428,7 +415,7 @@ class DataIter: # pylint: disable=too-many-instance-attributes
|
||||
Parameters
|
||||
----------
|
||||
|
||||
data_handle:
|
||||
input_data:
|
||||
A function with same data fields like `data`, `label` with
|
||||
`xgboost.DMatrix`.
|
||||
|
||||
@ -627,7 +614,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
if feature_types is not None:
|
||||
self.feature_types = feature_types
|
||||
|
||||
def _init_from_iter(self, iterator: DataIter, enable_categorical: bool):
|
||||
def _init_from_iter(self, iterator: DataIter, enable_categorical: bool) -> None:
|
||||
it = iterator
|
||||
args = {
|
||||
"missing": self.missing,
|
||||
@ -654,7 +641,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
_check_call(ret)
|
||||
self.handle = handle
|
||||
|
||||
def __del__(self):
|
||||
def __del__(self) -> None:
|
||||
if hasattr(self, "handle") and self.handle:
|
||||
_check_call(_LIB.XGDMatrixFree(self.handle))
|
||||
self.handle = None
|
||||
@ -699,7 +686,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
dispatch_meta_backend(matrix=self, data=feature_weights,
|
||||
name='feature_weights')
|
||||
|
||||
def get_float_info(self, field):
|
||||
def get_float_info(self, field: str) -> np.ndarray:
|
||||
"""Get float property from the DMatrix.
|
||||
|
||||
Parameters
|
||||
@ -720,7 +707,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
ctypes.byref(ret)))
|
||||
return ctypes2numpy(ret, length.value, np.float32)
|
||||
|
||||
def get_uint_info(self, field):
|
||||
def get_uint_info(self, field: str) -> np.ndarray:
|
||||
"""Get unsigned integer property from the DMatrix.
|
||||
|
||||
Parameters
|
||||
@ -741,7 +728,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
ctypes.byref(ret)))
|
||||
return ctypes2numpy(ret, length.value, np.uint32)
|
||||
|
||||
def set_float_info(self, field, data):
|
||||
def set_float_info(self, field: str, data) -> None:
|
||||
"""Set float type property into the DMatrix.
|
||||
|
||||
Parameters
|
||||
@ -755,7 +742,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
from .data import dispatch_meta_backend
|
||||
dispatch_meta_backend(self, data, field, 'float')
|
||||
|
||||
def set_float_info_npy2d(self, field, data):
|
||||
def set_float_info_npy2d(self, field: str, data) -> None:
|
||||
"""Set float type property into the DMatrix
|
||||
for numpy 2d array input
|
||||
|
||||
@ -770,7 +757,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
from .data import dispatch_meta_backend
|
||||
dispatch_meta_backend(self, data, field, 'float')
|
||||
|
||||
def set_uint_info(self, field, data):
|
||||
def set_uint_info(self, field: str, data) -> None:
|
||||
"""Set uint type property into the DMatrix.
|
||||
|
||||
Parameters
|
||||
@ -784,7 +771,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
from .data import dispatch_meta_backend
|
||||
dispatch_meta_backend(self, data, field, 'uint32')
|
||||
|
||||
def save_binary(self, fname, silent=True):
|
||||
def save_binary(self, fname, silent=True) -> None:
|
||||
"""Save DMatrix to an XGBoost buffer. Saved binary can be later loaded
|
||||
by providing the path to :py:func:`xgboost.DMatrix` as input.
|
||||
|
||||
@ -800,7 +787,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
c_str(fname),
|
||||
ctypes.c_int(silent)))
|
||||
|
||||
def set_label(self, label):
|
||||
def set_label(self, label) -> None:
|
||||
"""Set label of dmatrix
|
||||
|
||||
Parameters
|
||||
@ -811,7 +798,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
from .data import dispatch_meta_backend
|
||||
dispatch_meta_backend(self, label, 'label', 'float')
|
||||
|
||||
def set_weight(self, weight):
|
||||
def set_weight(self, weight) -> None:
|
||||
"""Set weight of each instance.
|
||||
|
||||
Parameters
|
||||
@ -830,7 +817,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
from .data import dispatch_meta_backend
|
||||
dispatch_meta_backend(self, weight, 'weight', 'float')
|
||||
|
||||
def set_base_margin(self, margin):
|
||||
def set_base_margin(self, margin) -> None:
|
||||
"""Set base margin of booster to start from.
|
||||
|
||||
This can be used to specify a prediction value of existing model to be
|
||||
@ -847,7 +834,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
from .data import dispatch_meta_backend
|
||||
dispatch_meta_backend(self, margin, 'base_margin', 'float')
|
||||
|
||||
def set_group(self, group):
|
||||
def set_group(self, group) -> None:
|
||||
"""Set group size of DMatrix (used for ranking).
|
||||
|
||||
Parameters
|
||||
@ -858,7 +845,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
from .data import dispatch_meta_backend
|
||||
dispatch_meta_backend(self, group, 'group', 'uint32')
|
||||
|
||||
def get_label(self):
|
||||
def get_label(self) -> np.ndarray:
|
||||
"""Get the label of the DMatrix.
|
||||
|
||||
Returns
|
||||
@ -867,7 +854,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
"""
|
||||
return self.get_float_info('label')
|
||||
|
||||
def get_weight(self):
|
||||
def get_weight(self) -> np.ndarray:
|
||||
"""Get the weight of the DMatrix.
|
||||
|
||||
Returns
|
||||
@ -876,7 +863,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
"""
|
||||
return self.get_float_info('weight')
|
||||
|
||||
def get_base_margin(self):
|
||||
def get_base_margin(self) -> np.ndarray:
|
||||
"""Get the base margin of the DMatrix.
|
||||
|
||||
Returns
|
||||
@ -885,7 +872,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
"""
|
||||
return self.get_float_info('base_margin')
|
||||
|
||||
def num_row(self):
|
||||
def num_row(self) -> int:
|
||||
"""Get the number of rows in the DMatrix.
|
||||
|
||||
Returns
|
||||
@ -897,7 +884,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
ctypes.byref(ret)))
|
||||
return ret.value
|
||||
|
||||
def num_col(self):
|
||||
def num_col(self) -> int:
|
||||
"""Get the number of columns (features) in the DMatrix.
|
||||
|
||||
Returns
|
||||
@ -1191,7 +1178,7 @@ class DeviceQuantileDMatrix(DMatrix):
|
||||
enable_categorical=enable_categorical,
|
||||
)
|
||||
|
||||
def _init(self, data, enable_categorical, **meta):
|
||||
def _init(self, data, enable_categorical: bool, **meta) -> None:
|
||||
from .data import (
|
||||
_is_dlpack,
|
||||
_transform_dlpack,
|
||||
@ -1265,7 +1252,7 @@ def _get_booster_layer_trees(model: "Booster") -> Tuple[int, int]:
|
||||
return num_parallel_tree, num_groups
|
||||
|
||||
|
||||
class Booster(object):
|
||||
class Booster:
|
||||
# pylint: disable=too-many-public-methods
|
||||
"""A Booster of XGBoost.
|
||||
|
||||
@ -1273,7 +1260,12 @@ class Booster(object):
|
||||
training, prediction and evaluation.
|
||||
"""
|
||||
|
||||
def __init__(self, params=None, cache=(), model_file=None):
|
||||
def __init__(
|
||||
self,
|
||||
params: Optional[Dict] = None,
|
||||
cache: Optional[Sequence[DMatrix]] = None,
|
||||
model_file: Optional[Union["Booster", bytearray, os.PathLike, str]] = None
|
||||
) -> None:
|
||||
# pylint: disable=invalid-name
|
||||
"""
|
||||
Parameters
|
||||
@ -1285,12 +1277,13 @@ class Booster(object):
|
||||
model_file : string/os.PathLike/Booster/bytearray
|
||||
Path to the model file if it's string or PathLike.
|
||||
"""
|
||||
cache = cache if cache is not None else []
|
||||
for d in cache:
|
||||
if not isinstance(d, DMatrix):
|
||||
raise TypeError(f'invalid cache item: {type(d).__name__}', cache)
|
||||
|
||||
dmats = c_array(ctypes.c_void_p, [d.handle for d in cache])
|
||||
self.handle = ctypes.c_void_p()
|
||||
self.handle: Optional[ctypes.c_void_p] = ctypes.c_void_p()
|
||||
_check_call(_LIB.XGBoosterCreate(dmats, c_bst_ulong(len(cache)),
|
||||
ctypes.byref(self.handle)))
|
||||
for d in cache:
|
||||
@ -1405,12 +1398,12 @@ class Booster(object):
|
||||
|
||||
return params
|
||||
|
||||
def __del__(self):
|
||||
def __del__(self) -> None:
|
||||
if hasattr(self, 'handle') and self.handle is not None:
|
||||
_check_call(_LIB.XGBoosterFree(self.handle))
|
||||
self.handle = None
|
||||
|
||||
def __getstate__(self):
|
||||
def __getstate__(self) -> Dict:
|
||||
# can't pickle ctypes pointers, put model content in bytearray
|
||||
this = self.__dict__.copy()
|
||||
handle = this['handle']
|
||||
@ -1424,7 +1417,7 @@ class Booster(object):
|
||||
this["handle"] = buf
|
||||
return this
|
||||
|
||||
def __setstate__(self, state):
|
||||
def __setstate__(self, state: Dict) -> None:
|
||||
# reconstruct handle from raw data
|
||||
handle = state['handle']
|
||||
if handle is not None:
|
||||
@ -1440,7 +1433,7 @@ class Booster(object):
|
||||
state['handle'] = handle
|
||||
self.__dict__.update(state)
|
||||
|
||||
def __getitem__(self, val):
|
||||
def __getitem__(self, val: Union[int, tuple, slice]) -> "Booster":
|
||||
if isinstance(val, int):
|
||||
val = slice(val, val+1)
|
||||
if isinstance(val, tuple):
|
||||
@ -1461,13 +1454,14 @@ class Booster(object):
|
||||
|
||||
step = val.step if val.step is not None else 1
|
||||
|
||||
start = ctypes.c_int(start)
|
||||
stop = ctypes.c_int(stop)
|
||||
step = ctypes.c_int(step)
|
||||
c_start = ctypes.c_int(start)
|
||||
c_stop = ctypes.c_int(stop)
|
||||
c_step = ctypes.c_int(step)
|
||||
|
||||
sliced_handle = ctypes.c_void_p()
|
||||
status = _LIB.XGBoosterSlice(self.handle, start, stop, step,
|
||||
ctypes.byref(sliced_handle))
|
||||
status = _LIB.XGBoosterSlice(
|
||||
self.handle, c_start, c_stop, c_step, ctypes.byref(sliced_handle)
|
||||
)
|
||||
if status == -2:
|
||||
raise IndexError('Layer index out of range')
|
||||
_check_call(status)
|
||||
@ -1477,7 +1471,7 @@ class Booster(object):
|
||||
sliced.handle = sliced_handle
|
||||
return sliced
|
||||
|
||||
def save_config(self):
|
||||
def save_config(self) -> str:
|
||||
'''Output internal parameter configuration of Booster as a JSON
|
||||
string.
|
||||
|
||||
@ -1489,10 +1483,11 @@ class Booster(object):
|
||||
self.handle,
|
||||
ctypes.byref(length),
|
||||
ctypes.byref(json_string)))
|
||||
json_string = json_string.value.decode() # pylint: disable=no-member
|
||||
return json_string
|
||||
assert json_string.value is not None
|
||||
result = json_string.value.decode() # pylint: disable=no-member
|
||||
return result
|
||||
|
||||
def load_config(self, config):
|
||||
def load_config(self, config: str) -> None:
|
||||
'''Load configuration returned by `save_config`.
|
||||
|
||||
.. versionadded:: 1.0.0
|
||||
@ -1502,14 +1497,14 @@ class Booster(object):
|
||||
self.handle,
|
||||
c_str(config)))
|
||||
|
||||
def __copy__(self):
|
||||
def __copy__(self) -> "Booster":
|
||||
return self.__deepcopy__(None)
|
||||
|
||||
def __deepcopy__(self, _):
|
||||
def __deepcopy__(self, _) -> "Booster":
|
||||
'''Return a copy of booster.'''
|
||||
return Booster(model_file=self)
|
||||
|
||||
def copy(self):
|
||||
def copy(self) -> "Booster":
|
||||
"""Copy the booster object.
|
||||
|
||||
Returns
|
||||
@ -1519,7 +1514,7 @@ class Booster(object):
|
||||
"""
|
||||
return self.__copy__()
|
||||
|
||||
def attr(self, key):
|
||||
def attr(self, key: str) -> Optional[str]:
|
||||
"""Get attribute string from the Booster.
|
||||
|
||||
Parameters
|
||||
@ -1540,7 +1535,7 @@ class Booster(object):
|
||||
return py_str(ret.value)
|
||||
return None
|
||||
|
||||
def attributes(self):
|
||||
def attributes(self) -> Dict[str, str]:
|
||||
"""Get attributes stored in the Booster as a dictionary.
|
||||
|
||||
Returns
|
||||
@ -1572,7 +1567,7 @@ class Booster(object):
|
||||
_check_call(_LIB.XGBoosterSetAttr(
|
||||
self.handle, c_str(key), value))
|
||||
|
||||
def _get_feature_info(self, field: str):
|
||||
def _get_feature_info(self, field: str) -> Optional[List[str]]:
|
||||
length = c_bst_ulong()
|
||||
sarr = ctypes.POINTER(ctypes.c_char_p)()
|
||||
if not hasattr(self, "handle") or self.handle is None:
|
||||
@ -1585,22 +1580,6 @@ class Booster(object):
|
||||
feature_info = from_cstr_to_pystr(sarr, length)
|
||||
return feature_info if feature_info else None
|
||||
|
||||
@property
|
||||
def feature_types(self) -> Optional[List[str]]:
|
||||
"""Feature types for this booster. Can be directly set by input data or by
|
||||
assignment.
|
||||
|
||||
"""
|
||||
return self._get_feature_info("feature_type")
|
||||
|
||||
@property
|
||||
def feature_names(self) -> Optional[List[str]]:
|
||||
"""Feature names for this booster. Can be directly set by input data or by
|
||||
assignment.
|
||||
|
||||
"""
|
||||
return self._get_feature_info("feature_name")
|
||||
|
||||
def _set_feature_info(self, features: Optional[List[str]], field: str) -> None:
|
||||
if features is not None:
|
||||
assert isinstance(features, list)
|
||||
@ -1618,14 +1597,30 @@ class Booster(object):
|
||||
)
|
||||
)
|
||||
|
||||
@feature_names.setter
|
||||
def feature_names(self, features: Optional[List[str]]) -> None:
|
||||
self._set_feature_info(features, "feature_name")
|
||||
@property
|
||||
def feature_types(self) -> Optional[List[str]]:
|
||||
"""Feature types for this booster. Can be directly set by input data or by
|
||||
assignment.
|
||||
|
||||
"""
|
||||
return self._get_feature_info("feature_type")
|
||||
|
||||
@feature_types.setter
|
||||
def feature_types(self, features: Optional[List[str]]) -> None:
|
||||
self._set_feature_info(features, "feature_type")
|
||||
|
||||
@property
|
||||
def feature_names(self) -> Optional[List[str]]:
|
||||
"""Feature names for this booster. Can be directly set by input data or by
|
||||
assignment.
|
||||
|
||||
"""
|
||||
return self._get_feature_info("feature_name")
|
||||
|
||||
@feature_names.setter
|
||||
def feature_names(self, features: Optional[List[str]]) -> None:
|
||||
self._set_feature_info(features, "feature_name")
|
||||
|
||||
def set_param(self, params, value=None):
|
||||
"""Set parameters into the Booster.
|
||||
|
||||
@ -1645,7 +1640,9 @@ class Booster(object):
|
||||
_check_call(_LIB.XGBoosterSetParam(self.handle, c_str(key),
|
||||
c_str(str(val))))
|
||||
|
||||
def update(self, dtrain, iteration, fobj=None):
|
||||
def update(
|
||||
self, dtrain: DMatrix, iteration: int, fobj: Optional[Objective] = None
|
||||
) -> None:
|
||||
"""Update for one iteration, with objective function calculated
|
||||
internally. This function should not be called directly by users.
|
||||
|
||||
@ -1672,18 +1669,18 @@ class Booster(object):
|
||||
grad, hess = fobj(pred, dtrain)
|
||||
self.boost(dtrain, grad, hess)
|
||||
|
||||
def boost(self, dtrain, grad, hess):
|
||||
def boost(self, dtrain: DMatrix, grad: np.ndarray, hess: np.ndarray) -> None:
|
||||
"""Boost the booster for one iteration, with customized gradient
|
||||
statistics. Like :py:func:`xgboost.Booster.update`, this
|
||||
function should not be called directly by users.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dtrain : DMatrix
|
||||
dtrain :
|
||||
The training DMatrix.
|
||||
grad : list
|
||||
grad :
|
||||
The first order of gradient.
|
||||
hess : list
|
||||
hess :
|
||||
The second order of gradient.
|
||||
|
||||
"""
|
||||
@ -1700,17 +1697,23 @@ class Booster(object):
|
||||
c_array(ctypes.c_float, hess),
|
||||
c_bst_ulong(len(grad))))
|
||||
|
||||
def eval_set(self, evals, iteration=0, feval=None, output_margin=True):
|
||||
def eval_set(
|
||||
self,
|
||||
evals: Sequence[Tuple[DMatrix, str]],
|
||||
iteration: int = 0,
|
||||
feval: Optional[Metric] = None,
|
||||
output_margin: bool = True
|
||||
) -> str:
|
||||
# pylint: disable=invalid-name
|
||||
"""Evaluate a set of data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
evals : list of tuples (DMatrix, string)
|
||||
evals :
|
||||
List of items to be evaluated.
|
||||
iteration : int
|
||||
iteration :
|
||||
Current iteration.
|
||||
feval : function
|
||||
feval :
|
||||
Custom evaluation function.
|
||||
|
||||
Returns
|
||||
@ -1738,6 +1741,7 @@ class Booster(object):
|
||||
ctypes.byref(msg),
|
||||
)
|
||||
)
|
||||
assert msg.value is not None
|
||||
res = msg.value.decode() # pylint: disable=no-member
|
||||
if feval is not None:
|
||||
for dmat, evname in evals:
|
||||
@ -1754,18 +1758,18 @@ class Booster(object):
|
||||
res += "\t%s-%s:%f" % (evname, name, val)
|
||||
return res
|
||||
|
||||
def eval(self, data, name='eval', iteration=0):
|
||||
def eval(self, data: DMatrix, name: str = 'eval', iteration: int = 0) -> str:
|
||||
"""Evaluate the model on mat.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : DMatrix
|
||||
data :
|
||||
The dmatrix storing the input.
|
||||
|
||||
name : str, optional
|
||||
name :
|
||||
The name of the dataset.
|
||||
|
||||
iteration : int, optional
|
||||
iteration :
|
||||
The current iteration number.
|
||||
|
||||
Returns
|
||||
@ -2101,7 +2105,7 @@ class Booster(object):
|
||||
"Data type:" + str(type(data)) + " not supported by inplace prediction."
|
||||
)
|
||||
|
||||
def save_model(self, fname: Union[str, os.PathLike]):
|
||||
def save_model(self, fname: Union[str, os.PathLike]) -> None:
|
||||
"""Save the model to a file.
|
||||
|
||||
The model is saved in an XGBoost internal format which is universal among the
|
||||
@ -2124,7 +2128,7 @@ class Booster(object):
|
||||
else:
|
||||
raise TypeError("fname must be a string or os PathLike")
|
||||
|
||||
def save_raw(self):
|
||||
def save_raw(self) -> bytearray:
|
||||
"""Save the model to a in memory buffer representation instead of file.
|
||||
|
||||
Returns
|
||||
@ -2232,18 +2236,23 @@ class Booster(object):
|
||||
if need_close:
|
||||
fout.close()
|
||||
|
||||
def get_dump(self, fmap='', with_stats=False, dump_format="text"):
|
||||
def get_dump(
|
||||
self,
|
||||
fmap: Union[str, os.PathLike] = "",
|
||||
with_stats: bool = False,
|
||||
dump_format: str = "text"
|
||||
) -> List[str]:
|
||||
"""Returns the model dump as a list of strings. Unlike `save_model`, the
|
||||
output format is primarily used for visualization or interpretation,
|
||||
hence it's more human readable but cannot be loaded back to XGBoost.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fmap : string or os.PathLike, optional
|
||||
fmap :
|
||||
Name of the file containing feature map names.
|
||||
with_stats : bool, optional
|
||||
with_stats :
|
||||
Controls whether the split statistics are output.
|
||||
dump_format : string, optional
|
||||
dump_format :
|
||||
Format of model dump. Can be 'text', 'json' or 'dot'.
|
||||
|
||||
"""
|
||||
@ -2259,7 +2268,9 @@ class Booster(object):
|
||||
res = from_cstr_to_pystr(sarr, length)
|
||||
return res
|
||||
|
||||
def get_fscore(self, fmap=''):
|
||||
def get_fscore(
|
||||
self, fmap: Union[str, os.PathLike] = ""
|
||||
) -> Dict[str, Union[float, List[float]]]:
|
||||
"""Get feature importance of each feature.
|
||||
|
||||
.. note:: Zero-importance features will not be included
|
||||
@ -2269,7 +2280,7 @@ class Booster(object):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fmap: str or os.PathLike (optional)
|
||||
fmap :
|
||||
The name of feature map file
|
||||
"""
|
||||
|
||||
@ -2299,9 +2310,9 @@ class Booster(object):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fmap: str or os.PathLike (optional)
|
||||
fmap:
|
||||
The name of feature map file.
|
||||
importance_type: str, default 'weight'
|
||||
importance_type:
|
||||
One of the importance types defined above.
|
||||
|
||||
Returns
|
||||
@ -2343,7 +2354,8 @@ class Booster(object):
|
||||
results[feat] = float(score)
|
||||
return results
|
||||
|
||||
def trees_to_dataframe(self, fmap=''): # pylint: disable=too-many-statements
|
||||
# pylint: disable=too-many-statements
|
||||
def trees_to_dataframe(self, fmap: Union[str, os.PathLike] = '') -> DataFrame:
|
||||
"""Parse a boosted tree model text dump into a pandas DataFrame structure.
|
||||
|
||||
This feature is only defined when the decision tree model is chosen as base
|
||||
@ -2370,7 +2382,7 @@ class Booster(object):
|
||||
node_ids = []
|
||||
fids = []
|
||||
splits = []
|
||||
categories = []
|
||||
categories: List[Optional[float]] = []
|
||||
y_directs = []
|
||||
n_directs = []
|
||||
missings = []
|
||||
@ -2444,7 +2456,7 @@ class Booster(object):
|
||||
# pylint: disable=no-member
|
||||
return df.sort(['Tree', 'Node']).reset_index(drop=True)
|
||||
|
||||
def _validate_features(self, data: DMatrix):
|
||||
def _validate_features(self, data: DMatrix) -> None:
|
||||
"""
|
||||
Validate Booster and data's feature_names are identical.
|
||||
Set feature_names and feature_types from DMatrix
|
||||
|
||||
@ -17,12 +17,13 @@ https://github.com/dask/dask-xgboost
|
||||
"""
|
||||
import platform
|
||||
import logging
|
||||
import collections
|
||||
from contextlib import contextmanager
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from threading import Thread
|
||||
from functools import partial, update_wrapper
|
||||
from typing import TYPE_CHECKING, List, Tuple, Callable, Optional, Any, Union, Dict, Set
|
||||
from typing import Sequence
|
||||
from typing import Awaitable, Generator, TypeVar
|
||||
|
||||
import numpy
|
||||
@ -524,9 +525,9 @@ class DaskPartitionIter(DataIter): # pylint: disable=R0902
|
||||
self._feature_names = feature_names
|
||||
self._feature_types = feature_types
|
||||
|
||||
assert isinstance(self._data, Sequence)
|
||||
assert isinstance(self._data, collections.abc.Sequence)
|
||||
|
||||
types = (Sequence, type(None))
|
||||
types = (collections.abc.Sequence, type(None))
|
||||
assert isinstance(self._labels, types)
|
||||
assert isinstance(self._weights, types)
|
||||
assert isinstance(self._base_margin, types)
|
||||
@ -817,7 +818,7 @@ async def _get_rabit_args(n_workers: int, client: "distributed.Client") -> List[
|
||||
|
||||
def _get_workers_from_data(
|
||||
dtrain: DaskDMatrix,
|
||||
evals: Optional[List[Tuple[DaskDMatrix, str]]]
|
||||
evals: Optional[Sequence[Tuple[DaskDMatrix, str]]]
|
||||
) -> List[str]:
|
||||
X_worker_map: Set[str] = set(dtrain.worker_map.keys())
|
||||
if evals:
|
||||
@ -837,13 +838,13 @@ async def _train_async(
|
||||
params: Dict[str, Any],
|
||||
dtrain: DaskDMatrix,
|
||||
num_boost_round: int,
|
||||
evals: Optional[List[Tuple[DaskDMatrix, str]]],
|
||||
evals: Optional[Sequence[Tuple[DaskDMatrix, str]]],
|
||||
obj: Optional[Objective],
|
||||
feval: Optional[Metric],
|
||||
early_stopping_rounds: Optional[int],
|
||||
verbose_eval: Union[int, bool],
|
||||
xgb_model: Optional[Booster],
|
||||
callbacks: Optional[List[TrainingCallback]],
|
||||
callbacks: Optional[Sequence[TrainingCallback]],
|
||||
custom_metric: Optional[Metric],
|
||||
) -> Optional[TrainReturnT]:
|
||||
workers = _get_workers_from_data(dtrain, evals)
|
||||
@ -951,13 +952,13 @@ def train( # pylint: disable=unused-argument
|
||||
dtrain: DaskDMatrix,
|
||||
num_boost_round: int = 10,
|
||||
*,
|
||||
evals: Optional[List[Tuple[DaskDMatrix, str]]] = None,
|
||||
evals: Optional[Sequence[Tuple[DaskDMatrix, str]]] = None,
|
||||
obj: Optional[Objective] = None,
|
||||
feval: Optional[Metric] = None,
|
||||
early_stopping_rounds: Optional[int] = None,
|
||||
xgb_model: Optional[Booster] = None,
|
||||
verbose_eval: Union[int, bool] = True,
|
||||
callbacks: Optional[List[TrainingCallback]] = None,
|
||||
callbacks: Optional[Sequence[TrainingCallback]] = None,
|
||||
custom_metric: Optional[Metric] = None,
|
||||
) -> Any:
|
||||
"""Train XGBoost model.
|
||||
@ -1648,15 +1649,15 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
||||
y: _DaskCollection,
|
||||
sample_weight: Optional[_DaskCollection],
|
||||
base_margin: Optional[_DaskCollection],
|
||||
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]],
|
||||
eval_metric: Optional[Union[str, List[str], Metric]],
|
||||
sample_weight_eval_set: Optional[List[_DaskCollection]],
|
||||
base_margin_eval_set: Optional[List[_DaskCollection]],
|
||||
eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]],
|
||||
eval_metric: Optional[Union[str, Sequence[str], Metric]],
|
||||
sample_weight_eval_set: Optional[Sequence[_DaskCollection]],
|
||||
base_margin_eval_set: Optional[Sequence[_DaskCollection]],
|
||||
early_stopping_rounds: Optional[int],
|
||||
verbose: bool,
|
||||
xgb_model: Optional[Union[Booster, XGBModel]],
|
||||
feature_weights: Optional[_DaskCollection],
|
||||
callbacks: Optional[List[TrainingCallback]],
|
||||
callbacks: Optional[Sequence[TrainingCallback]],
|
||||
) -> _DaskCollection:
|
||||
params = self.get_xgb_params()
|
||||
dtrain, evals = await _async_wrap_evaluation_matrices(
|
||||
@ -1714,15 +1715,15 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
||||
*,
|
||||
sample_weight: Optional[_DaskCollection] = None,
|
||||
base_margin: Optional[_DaskCollection] = None,
|
||||
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None,
|
||||
eval_metric: Optional[Union[str, List[str], Metric]] = None,
|
||||
eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
|
||||
eval_metric: Optional[Union[str, Sequence[str], Callable]] = None,
|
||||
early_stopping_rounds: Optional[int] = None,
|
||||
verbose: bool = True,
|
||||
xgb_model: Optional[Union[Booster, XGBModel]] = None,
|
||||
sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
|
||||
base_margin_eval_set: Optional[List[_DaskCollection]] = None,
|
||||
sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
|
||||
base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
|
||||
feature_weights: Optional[_DaskCollection] = None,
|
||||
callbacks: Optional[List[TrainingCallback]] = None,
|
||||
callbacks: Optional[Sequence[TrainingCallback]] = None,
|
||||
) -> "DaskXGBRegressor":
|
||||
_assert_dask_support()
|
||||
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
|
||||
@ -1738,15 +1739,15 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
self, X: _DaskCollection, y: _DaskCollection,
|
||||
sample_weight: Optional[_DaskCollection],
|
||||
base_margin: Optional[_DaskCollection],
|
||||
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]],
|
||||
eval_metric: Optional[Union[str, List[str], Metric]],
|
||||
sample_weight_eval_set: Optional[List[_DaskCollection]],
|
||||
base_margin_eval_set: Optional[List[_DaskCollection]],
|
||||
eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]],
|
||||
eval_metric: Optional[Union[str, Sequence[str], Metric]],
|
||||
sample_weight_eval_set: Optional[Sequence[_DaskCollection]],
|
||||
base_margin_eval_set: Optional[Sequence[_DaskCollection]],
|
||||
early_stopping_rounds: Optional[int],
|
||||
verbose: bool,
|
||||
xgb_model: Optional[Union[Booster, XGBModel]],
|
||||
feature_weights: Optional[_DaskCollection],
|
||||
callbacks: Optional[List[TrainingCallback]]
|
||||
callbacks: Optional[Sequence[TrainingCallback]]
|
||||
) -> "DaskXGBClassifier":
|
||||
params = self.get_xgb_params()
|
||||
dtrain, evals = await _async_wrap_evaluation_matrices(
|
||||
@ -1818,15 +1819,15 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
*,
|
||||
sample_weight: Optional[_DaskCollection] = None,
|
||||
base_margin: Optional[_DaskCollection] = None,
|
||||
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None,
|
||||
eval_metric: Optional[Union[str, List[str], Metric]] = None,
|
||||
eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
|
||||
eval_metric: Optional[Union[str, Sequence[str], Callable]] = None,
|
||||
early_stopping_rounds: Optional[int] = None,
|
||||
verbose: bool = True,
|
||||
xgb_model: Optional[Union[Booster, XGBModel]] = None,
|
||||
sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
|
||||
base_margin_eval_set: Optional[List[_DaskCollection]] = None,
|
||||
sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
|
||||
base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
|
||||
feature_weights: Optional[_DaskCollection] = None,
|
||||
callbacks: Optional[List[TrainingCallback]] = None
|
||||
callbacks: Optional[Sequence[TrainingCallback]] = None
|
||||
) -> "DaskXGBClassifier":
|
||||
_assert_dask_support()
|
||||
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
|
||||
@ -1935,17 +1936,17 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
|
||||
qid: Optional[_DaskCollection],
|
||||
sample_weight: Optional[_DaskCollection],
|
||||
base_margin: Optional[_DaskCollection],
|
||||
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]],
|
||||
sample_weight_eval_set: Optional[List[_DaskCollection]],
|
||||
base_margin_eval_set: Optional[List[_DaskCollection]],
|
||||
eval_group: Optional[List[_DaskCollection]],
|
||||
eval_qid: Optional[List[_DaskCollection]],
|
||||
eval_metric: Optional[Union[str, List[str], Metric]],
|
||||
eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]],
|
||||
sample_weight_eval_set: Optional[Sequence[_DaskCollection]],
|
||||
base_margin_eval_set: Optional[Sequence[_DaskCollection]],
|
||||
eval_group: Optional[Sequence[_DaskCollection]],
|
||||
eval_qid: Optional[Sequence[_DaskCollection]],
|
||||
eval_metric: Optional[Union[str, Sequence[str], Metric]],
|
||||
early_stopping_rounds: Optional[int],
|
||||
verbose: bool,
|
||||
xgb_model: Optional[Union[XGBModel, Booster]],
|
||||
feature_weights: Optional[_DaskCollection],
|
||||
callbacks: Optional[List[TrainingCallback]],
|
||||
callbacks: Optional[Sequence[TrainingCallback]],
|
||||
) -> "DaskXGBRanker":
|
||||
msg = "Use `qid` instead of `group` on dask interface."
|
||||
if not (group is None and eval_group is None):
|
||||
@ -2010,17 +2011,17 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
|
||||
qid: Optional[_DaskCollection] = None,
|
||||
sample_weight: Optional[_DaskCollection] = None,
|
||||
base_margin: Optional[_DaskCollection] = None,
|
||||
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None,
|
||||
eval_group: Optional[List[_DaskCollection]] = None,
|
||||
eval_qid: Optional[List[_DaskCollection]] = None,
|
||||
eval_metric: Optional[Union[str, List[str], Metric]] = None,
|
||||
eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
|
||||
eval_group: Optional[Sequence[_DaskCollection]] = None,
|
||||
eval_qid: Optional[Sequence[_DaskCollection]] = None,
|
||||
eval_metric: Optional[Union[str, Sequence[str], Callable]] = None,
|
||||
early_stopping_rounds: int = None,
|
||||
verbose: bool = False,
|
||||
xgb_model: Optional[Union[XGBModel, Booster]] = None,
|
||||
sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
|
||||
base_margin_eval_set: Optional[List[_DaskCollection]] = None,
|
||||
sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
|
||||
base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
|
||||
feature_weights: Optional[_DaskCollection] = None,
|
||||
callbacks: Optional[List[TrainingCallback]] = None
|
||||
callbacks: Optional[Sequence[TrainingCallback]] = None
|
||||
) -> "DaskXGBRanker":
|
||||
_assert_dask_support()
|
||||
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
|
||||
@ -2077,15 +2078,15 @@ class DaskXGBRFRegressor(DaskXGBRegressor):
|
||||
*,
|
||||
sample_weight: Optional[_DaskCollection] = None,
|
||||
base_margin: Optional[_DaskCollection] = None,
|
||||
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None,
|
||||
eval_metric: Optional[Union[str, List[str], Metric]] = None,
|
||||
eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
|
||||
eval_metric: Optional[Union[str, Sequence[str], Callable]] = None,
|
||||
early_stopping_rounds: Optional[int] = None,
|
||||
verbose: bool = True,
|
||||
xgb_model: Optional[Union[Booster, XGBModel]] = None,
|
||||
sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
|
||||
base_margin_eval_set: Optional[List[_DaskCollection]] = None,
|
||||
sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
|
||||
base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
|
||||
feature_weights: Optional[_DaskCollection] = None,
|
||||
callbacks: Optional[List[TrainingCallback]] = None
|
||||
callbacks: Optional[Sequence[TrainingCallback]] = None
|
||||
) -> "DaskXGBRFRegressor":
|
||||
_assert_dask_support()
|
||||
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
|
||||
@ -2141,15 +2142,15 @@ class DaskXGBRFClassifier(DaskXGBClassifier):
|
||||
*,
|
||||
sample_weight: Optional[_DaskCollection] = None,
|
||||
base_margin: Optional[_DaskCollection] = None,
|
||||
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None,
|
||||
eval_metric: Optional[Union[str, List[str], Metric]] = None,
|
||||
eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
|
||||
eval_metric: Optional[Union[str, Sequence[str], Callable]] = None,
|
||||
early_stopping_rounds: Optional[int] = None,
|
||||
verbose: bool = True,
|
||||
xgb_model: Optional[Union[Booster, XGBModel]] = None,
|
||||
sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
|
||||
base_margin_eval_set: Optional[List[_DaskCollection]] = None,
|
||||
sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
|
||||
base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
|
||||
feature_weights: Optional[_DaskCollection] = None,
|
||||
callbacks: Optional[List[TrainingCallback]] = None
|
||||
callbacks: Optional[Sequence[TrainingCallback]] = None
|
||||
) -> "DaskXGBRFClassifier":
|
||||
_assert_dask_support()
|
||||
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
|
||||
|
||||
@ -4,7 +4,8 @@ import copy
|
||||
import warnings
|
||||
import json
|
||||
import os
|
||||
from typing import Union, Optional, List, Dict, Callable, Tuple, Any, TypeVar, Type
|
||||
from typing import Union, Optional, List, Dict, Callable, Tuple, Any, TypeVar, Type, cast
|
||||
from typing import Sequence
|
||||
import numpy as np
|
||||
|
||||
from .core import Booster, DMatrix, XGBoostError
|
||||
@ -36,7 +37,7 @@ class XGBRankerMixIn: # pylint: disable=too-few-public-methods
|
||||
|
||||
def _check_rf_callback(
|
||||
early_stopping_rounds: Optional[int],
|
||||
callbacks: Optional[List[TrainingCallback]],
|
||||
callbacks: Optional[Sequence[TrainingCallback]],
|
||||
) -> None:
|
||||
if early_stopping_rounds is not None or callbacks is not None:
|
||||
raise NotImplementedError(
|
||||
@ -343,14 +344,14 @@ def _wrap_evaluation_matrices(
|
||||
sample_weight: Optional[Any],
|
||||
base_margin: Optional[Any],
|
||||
feature_weights: Optional[Any],
|
||||
eval_set: Optional[List[Tuple[Any, Any]]],
|
||||
sample_weight_eval_set: Optional[List[Any]],
|
||||
base_margin_eval_set: Optional[List[Any]],
|
||||
eval_group: Optional[List[Any]],
|
||||
eval_qid: Optional[List[Any]],
|
||||
eval_set: Optional[Sequence[Tuple[Any, Any]]],
|
||||
sample_weight_eval_set: Optional[Sequence[Any]],
|
||||
base_margin_eval_set: Optional[Sequence[Any]],
|
||||
eval_group: Optional[Sequence[Any]],
|
||||
eval_qid: Optional[Sequence[Any]],
|
||||
create_dmatrix: Callable,
|
||||
enable_categorical: bool,
|
||||
) -> Tuple[Any, Optional[List[Tuple[Any, str]]]]:
|
||||
) -> Tuple[Any, List[Tuple[Any, str]]]:
|
||||
"""Convert array_like evaluation matrices into DMatrix. Perform validation on the way.
|
||||
|
||||
"""
|
||||
@ -368,7 +369,7 @@ def _wrap_evaluation_matrices(
|
||||
|
||||
n_validation = 0 if eval_set is None else len(eval_set)
|
||||
|
||||
def validate_or_none(meta: Optional[List], name: str) -> List:
|
||||
def validate_or_none(meta: Optional[Sequence], name: str) -> Sequence:
|
||||
if meta is None:
|
||||
return [None] * n_validation
|
||||
if len(meta) != n_validation:
|
||||
@ -464,7 +465,7 @@ class XGBModel(XGBModelBase):
|
||||
missing: float = np.nan,
|
||||
num_parallel_tree: Optional[int] = None,
|
||||
monotone_constraints: Optional[Union[Dict[str, int], str]] = None,
|
||||
interaction_constraints: Optional[Union[str, List[Tuple[str]]]] = None,
|
||||
interaction_constraints: Optional[Union[str, Sequence[Sequence[str]]]] = None,
|
||||
importance_type: Optional[str] = None,
|
||||
gpu_id: Optional[int] = None,
|
||||
validate_parameters: Optional[bool] = None,
|
||||
@ -715,7 +716,7 @@ class XGBModel(XGBModelBase):
|
||||
def _configure_fit(
|
||||
self,
|
||||
booster: Optional[Union[Booster, "XGBModel", str]],
|
||||
eval_metric: Optional[Union[Callable, str, List[str]]],
|
||||
eval_metric: Optional[Union[Callable, str, Sequence[str]]],
|
||||
params: Dict[str, Any],
|
||||
early_stopping_rounds: Optional[int],
|
||||
) -> Tuple[
|
||||
@ -788,10 +789,7 @@ class XGBModel(XGBModelBase):
|
||||
|
||||
def _set_evaluation_result(self, evals_result: TrainingCallback.EvalsLog) -> None:
|
||||
if evals_result:
|
||||
for val in evals_result.items():
|
||||
evals_result_key = list(val[1].keys())[0]
|
||||
evals_result[val[0]][evals_result_key] = val[1][evals_result_key]
|
||||
self.evals_result_ = evals_result
|
||||
self.evals_result_ = cast(Dict[str, Dict[str, List[float]]], evals_result)
|
||||
|
||||
@_deprecate_positional_args
|
||||
def fit(
|
||||
@ -801,15 +799,15 @@ class XGBModel(XGBModelBase):
|
||||
*,
|
||||
sample_weight: Optional[array_like] = None,
|
||||
base_margin: Optional[array_like] = None,
|
||||
eval_set: Optional[List[Tuple[array_like, array_like]]] = None,
|
||||
eval_metric: Optional[Union[str, List[str], Metric]] = None,
|
||||
eval_set: Optional[Sequence[Tuple[array_like, array_like]]] = None,
|
||||
eval_metric: Optional[Union[str, Sequence[str], Metric]] = None,
|
||||
early_stopping_rounds: Optional[int] = None,
|
||||
verbose: Optional[bool] = True,
|
||||
xgb_model: Optional[Union[Booster, str, "XGBModel"]] = None,
|
||||
sample_weight_eval_set: Optional[List[array_like]] = None,
|
||||
base_margin_eval_set: Optional[List[array_like]] = None,
|
||||
sample_weight_eval_set: Optional[Sequence[array_like]] = None,
|
||||
base_margin_eval_set: Optional[Sequence[array_like]] = None,
|
||||
feature_weights: Optional[array_like] = None,
|
||||
callbacks: Optional[List[TrainingCallback]] = None
|
||||
callbacks: Optional[Sequence[TrainingCallback]] = None
|
||||
) -> "XGBModel":
|
||||
# pylint: disable=invalid-name,attribute-defined-outside-init
|
||||
"""Fit gradient boosting model.
@ -1031,7 +1029,7 @@ class XGBModel(XGBModelBase):
Input features matrix.

iteration_range :
See :py:meth:`xgboost.XGBRegressor.predict`.
See :py:meth:`predict`.

ntree_limit :
Deprecated, use ``iteration_range`` instead.
@ -1055,40 +1053,26 @@ class XGBModel(XGBModelBase):
iteration_range=iteration_range
)

def evals_result(self) -> TrainingCallback.EvalsLog:
def evals_result(self) -> Dict[str, Dict[str, List[float]]]:
"""Return the evaluation results.

If **eval_set** is passed to the `fit` function, you can call
``evals_result()`` to get evaluation results for all passed **eval_sets**.
When **eval_metric** is also passed to the `fit` function, the
**evals_result** will contain the **eval_metrics** passed to the `fit` function.
If **eval_set** is passed to the :py:meth:`fit` function, you can call
``evals_result()`` to get evaluation results for all passed **eval_sets**. When
**eval_metric** is also passed to the :py:meth:`fit` function, the
**evals_result** will contain the **eval_metrics** passed to the :py:meth:`fit`
function.

Returns
-------
evals_result : dictionary

Example
-------

.. code-block:: python

param_dist = {'objective':'binary:logistic', 'n_estimators':2}

clf = xgb.XGBModel(**param_dist)

clf.fit(X_train, y_train,
eval_set=[(X_train, y_train), (X_test, y_test)],
eval_metric='logloss',
verbose=True)

evals_result = clf.evals_result()

The variable **evals_result** will contain:
The returned evaluation result is a dictionary:

.. code-block:: python

{'validation_0': {'logloss': ['0.604835', '0.531479']},
'validation_1': {'logloss': ['0.41965', '0.17686']}}

Returns
-------
evals_result

"""
if getattr(self, "evals_result_", None) is not None:
evals_result = self.evals_result_
@ -1193,8 +1177,8 @@ class XGBModel(XGBModelBase):
.. note:: Intercept is defined only for linear learners

Intercept (bias) is only defined when the linear model is chosen as base
learner (`booster=gblinear`). It is not defined for other base learner types, such
as tree learners (`booster=gbtree`).
learner (`booster=gblinear`). It is not defined for other base learner types,
such as tree learners (`booster=gbtree`).

Returns
-------
@ -1251,15 +1235,15 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
*,
sample_weight: Optional[array_like] = None,
base_margin: Optional[array_like] = None,
eval_set: Optional[List[Tuple[array_like, array_like]]] = None,
eval_metric: Optional[Union[str, List[str], Metric]] = None,
eval_set: Optional[Sequence[Tuple[array_like, array_like]]] = None,
eval_metric: Optional[Union[str, Sequence[str], Metric]] = None,
early_stopping_rounds: Optional[int] = None,
verbose: Optional[bool] = True,
xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
sample_weight_eval_set: Optional[List[array_like]] = None,
base_margin_eval_set: Optional[List[array_like]] = None,
sample_weight_eval_set: Optional[Sequence[array_like]] = None,
base_margin_eval_set: Optional[Sequence[array_like]] = None,
feature_weights: Optional[array_like] = None,
callbacks: Optional[List[TrainingCallback]] = None
callbacks: Optional[Sequence[TrainingCallback]] = None
) -> "XGBClassifier":
# pylint: disable = attribute-defined-outside-init,too-many-statements
evals_result: TrainingCallback.EvalsLog = {}
@ -1445,51 +1429,6 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
getattr(self, "n_classes_", None), class_probs, np.vstack
)

def evals_result(self) -> TrainingCallback.EvalsLog:
"""Return the evaluation results.

If **eval_set** is passed to the `fit` function, you can call
``evals_result()`` to get evaluation results for all passed **eval_sets**.

When **eval_metric** is also passed as a parameter, the **evals_result** will
contain the **eval_metric** passed to the `fit` function.

Returns
-------
evals_result : dictionary

Example
-------

.. code-block:: python

param_dist = {
'objective':'binary:logistic', 'n_estimators':2, eval_metric="logloss"
}

clf = xgb.XGBClassifier(**param_dist)

clf.fit(X_train, y_train,
eval_set=[(X_train, y_train), (X_test, y_test)],
verbose=True)

evals_result = clf.evals_result()

The variable **evals_result** will contain

.. code-block:: python

{'validation_0': {'logloss': ['0.604835', '0.531479']},
'validation_1': {'logloss': ['0.41965', '0.17686']}}

"""
if self.evals_result_:
evals_result = self.evals_result_
else:
raise XGBoostError('No results.')

return evals_result


@xgboost_model_doc(
"scikit-learn API for XGBoost random forest classification.",
@ -1533,15 +1472,15 @@ class XGBRFClassifier(XGBClassifier):
*,
sample_weight: Optional[array_like] = None,
base_margin: Optional[array_like] = None,
eval_set: Optional[List[Tuple[array_like, array_like]]] = None,
eval_metric: Optional[Union[str, List[str], Metric]] = None,
eval_set: Optional[Sequence[Tuple[array_like, array_like]]] = None,
eval_metric: Optional[Union[str, Sequence[str], Metric]] = None,
early_stopping_rounds: Optional[int] = None,
verbose: Optional[bool] = True,
xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
sample_weight_eval_set: Optional[List[array_like]] = None,
base_margin_eval_set: Optional[List[array_like]] = None,
sample_weight_eval_set: Optional[Sequence[array_like]] = None,
base_margin_eval_set: Optional[Sequence[array_like]] = None,
feature_weights: Optional[array_like] = None,
callbacks: Optional[List[TrainingCallback]] = None
callbacks: Optional[Sequence[TrainingCallback]] = None
) -> "XGBRFClassifier":
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
_check_rf_callback(early_stopping_rounds, callbacks)
@ -1605,15 +1544,15 @@ class XGBRFRegressor(XGBRegressor):
*,
sample_weight: Optional[array_like] = None,
base_margin: Optional[array_like] = None,
eval_set: Optional[List[Tuple[array_like, array_like]]] = None,
eval_metric: Optional[Union[str, List[str], Metric]] = None,
eval_set: Optional[Sequence[Tuple[array_like, array_like]]] = None,
eval_metric: Optional[Union[str, Sequence[str], Metric]] = None,
early_stopping_rounds: Optional[int] = None,
verbose: Optional[bool] = True,
xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
sample_weight_eval_set: Optional[List[array_like]] = None,
base_margin_eval_set: Optional[List[array_like]] = None,
sample_weight_eval_set: Optional[Sequence[array_like]] = None,
base_margin_eval_set: Optional[Sequence[array_like]] = None,
feature_weights: Optional[array_like] = None,
callbacks: Optional[List[TrainingCallback]] = None
callbacks: Optional[Sequence[TrainingCallback]] = None
) -> "XGBRFRegressor":
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
_check_rf_callback(early_stopping_rounds, callbacks)
@ -1682,17 +1621,17 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
qid: Optional[array_like] = None,
sample_weight: Optional[array_like] = None,
base_margin: Optional[array_like] = None,
eval_set: Optional[List[Tuple[array_like, array_like]]] = None,
eval_group: Optional[List[array_like]] = None,
eval_qid: Optional[List[array_like]] = None,
eval_metric: Optional[Union[str, List[str], Metric]] = None,
eval_set: Optional[Sequence[Tuple[array_like, array_like]]] = None,
eval_group: Optional[Sequence[array_like]] = None,
eval_qid: Optional[Sequence[array_like]] = None,
eval_metric: Optional[Union[str, Sequence[str], Metric]] = None,
early_stopping_rounds: Optional[int] = None,
verbose: Optional[bool] = False,
xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
sample_weight_eval_set: Optional[List[array_like]] = None,
base_margin_eval_set: Optional[List[array_like]] = None,
sample_weight_eval_set: Optional[Sequence[array_like]] = None,
base_margin_eval_set: Optional[Sequence[array_like]] = None,
feature_weights: Optional[array_like] = None,
callbacks: Optional[List[TrainingCallback]] = None
callbacks: Optional[Sequence[TrainingCallback]] = None
) -> "XGBRanker":
# pylint: disable = attribute-defined-outside-init,arguments-differ
"""Fit gradient boosting ranker

@ -13,7 +13,7 @@ from threading import Thread
import argparse
import sys

from typing import Dict, List, Tuple, Union, Optional
from typing import Dict, List, Tuple, Union, Optional

_RingMap = Dict[int, Tuple[int, int]]
_TreeMap = Dict[int, List[int]]

@ -3,18 +3,20 @@
# pylint: disable=too-many-branches, too-many-statements
"""Training Library containing training routines."""
import copy
from typing import Optional, List
import os
import warnings
from typing import Optional, Dict, Any, Union, Tuple, cast, Sequence

import numpy as np
from .core import Booster, XGBoostError, _get_booster_layer_trees
from .core import _deprecate_positional_args
from .core import Objective, Metric
from .core import Booster, DMatrix, XGBoostError, _get_booster_layer_trees
from .core import Metric, Objective
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
from . import callback


def _assert_new_callback(callbacks: Optional[List[callback.TrainingCallback]]) -> None:
def _assert_new_callback(
callbacks: Optional[Sequence[callback.TrainingCallback]]
) -> None:
is_new_callback: bool = not callbacks or all(
isinstance(c, callback.TrainingCallback) for c in callbacks
)
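
To make the check concrete: only instances of ``callback.TrainingCallback`` count as new-style callbacks, so mixing in a bare function (used here as a stand-in for a legacy callback) fails it. A small sketch:

.. code-block:: python

    from xgboost import callback

    def legacy_cb(env):
        # old-style callback: a plain function, not a TrainingCallback subclass
        pass

    cbs = [callback.EvaluationMonitor(), legacy_cb]
    is_new = not cbs or all(isinstance(c, callback.TrainingCallback) for c in cbs)
    print(is_new)  # False -> _assert_new_callback would reject this list
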
@ -44,24 +46,24 @@ def _configure_custom_metric(


def _train_internal(
params,
dtrain,
num_boost_round=10,
evals=(),
obj=None,
feval=None,
custom_metric=None,
xgb_model=None,
callbacks=None,
evals_result=None,
maximize=None,
verbose_eval=None,
early_stopping_rounds=None,
):
params: Dict[str, Any],
dtrain: DMatrix,
num_boost_round: int = 10,
evals: Optional[Sequence[Tuple[DMatrix, str]]] = None,
obj: Optional[Objective] = None,
feval: Optional[Metric] = None,
custom_metric: Optional[Metric] = None,
xgb_model: Optional[Union[str, os.PathLike, Booster, bytearray]] = None,
callbacks: Optional[Sequence[callback.TrainingCallback]] = None,
evals_result: callback.TrainingCallback.EvalsLog = None,
maximize: Optional[bool] = None,
verbose_eval: Optional[Union[bool, int]] = True,
early_stopping_rounds: Optional[int] = None,
) -> Booster:
"""internal training function"""
callbacks = [] if callbacks is None else copy.copy(callbacks)
callbacks = [] if callbacks is None else copy.copy(list(callbacks))
metric_fn = _configure_custom_metric(feval, custom_metric)
evals = list(evals)
evals = list(evals) if evals else []

bst = Booster(params, [dtrain] + [d[0] for d in evals])

@ -78,7 +80,7 @@ def _train_internal(
callbacks.append(
callback.EarlyStopping(rounds=early_stopping_rounds, maximize=maximize)
)
callbacks = callback.CallbackContainer(
cb_container = callback.CallbackContainer(
callbacks,
metric=metric_fn,
# For old `feval` parameter, the behavior is unchanged. For the new
@ -87,32 +89,32 @@ def _train_internal(
output_margin=callable(obj) or metric_fn is feval,
)

bst = callbacks.before_training(bst)
bst = cb_container.before_training(bst)

for i in range(start_iteration, num_boost_round):
if callbacks.before_iteration(bst, i, dtrain, evals):
if cb_container.before_iteration(bst, i, dtrain, evals):
break
bst.update(dtrain, i, obj)
if callbacks.after_iteration(bst, i, dtrain, evals):
if cb_container.after_iteration(bst, i, dtrain, evals):
break

bst = callbacks.after_training(bst)
bst = cb_container.after_training(bst)

if evals_result is not None:
evals_result.update(callbacks.history)
evals_result.update(cb_container.history)

# These should be moved into callback functions `after_training`, but until old
# callbacks are removed, the train function is the only place for setting the
# attributes.
num_parallel_tree, _ = _get_booster_layer_trees(bst)
if bst.attr('best_score') is not None:
bst.best_score = float(bst.attr('best_score'))
bst.best_iteration = int(bst.attr('best_iteration'))
bst.best_score = float(cast(str, bst.attr('best_score')))
bst.best_iteration = int(cast(str, bst.attr('best_iteration')))
# num_class is handled internally
bst.set_attr(
best_ntree_limit=str((bst.best_iteration + 1) * num_parallel_tree)
)
bst.best_ntree_limit = int(bst.attr("best_ntree_limit"))
bst.best_ntree_limit = int(cast(str, bst.attr("best_ntree_limit")))
else:
# Due to compatibility with version older than 1.4, these attributes are added
# to Python object even if early stopping is not used.
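
As the comment above says, this bookkeeping conceptually belongs in a callback's ``after_training``. A hedged sketch of a new-style callback the container drives, with an illustrative attribute name that is not part of the change:

.. code-block:: python

    import xgboost as xgb

    class RecordRounds(xgb.callback.TrainingCallback):
        """Track completed rounds; returning True from after_iteration stops training."""

        def __init__(self) -> None:
            super().__init__()
            self.rounds = 0

        def after_iteration(self, model, epoch, evals_log) -> bool:
            self.rounds = epoch + 1
            return False  # keep training

        def after_training(self, model):
            model.set_attr(observed_rounds=str(self.rounds))  # illustrative attribute
            return model

Such a callback can then be passed through the ``callbacks`` argument of ``train`` or the sklearn ``fit`` methods shown earlier.
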
@ -126,35 +128,32 @@ def _train_internal(
return bst.copy()


@_deprecate_positional_args
def train(
params,
dtrain,
num_boost_round=10,
*,
evals=(),
params: Dict[str, Any],
dtrain: DMatrix,
num_boost_round: int = 10,
evals: Optional[Sequence[Tuple[DMatrix, str]]] = None,
obj: Optional[Objective] = None,
feval=None,
maximize=None,
early_stopping_rounds=None,
evals_result=None,
verbose_eval=True,
xgb_model=None,
callbacks=None,
feval: Optional[Metric] = None,
maximize: Optional[bool] = None,
early_stopping_rounds: Optional[int] = None,
evals_result: callback.TrainingCallback.EvalsLog = None,
verbose_eval: Optional[Union[bool, int]] = True,
xgb_model: Optional[Union[str, os.PathLike, Booster, bytearray]] = None,
callbacks: Optional[Sequence[callback.TrainingCallback]] = None,
custom_metric: Optional[Metric] = None,
):
# pylint: disable=too-many-statements,too-many-branches, attribute-defined-outside-init
) -> Booster:
"""Train a booster with given parameters.

Parameters
----------
params : dict
params :
Booster params.
dtrain : DMatrix
dtrain :
Data to be trained.
num_boost_round: int
num_boost_round :
Number of boosting iterations.
evals: list of pairs (DMatrix, string)
evals :
List of validation sets for which metrics will be evaluated during training.
Validation metrics will help us track the performance of the model.
obj
@ -166,7 +165,7 @@ def train(
Use `custom_metric` instead.
maximize : bool
Whether to maximize feval.
early_stopping_rounds: int
early_stopping_rounds :
Activates early stopping. Validation metric needs to improve at least once in
every **early_stopping_rounds** round(s) to continue training.
Requires at least one item in **evals**.
@ -178,7 +177,7 @@ def train(
**params**, the last metric will be used for early stopping.
If early stopping occurs, the model will have two additional fields:
``bst.best_score``, ``bst.best_iteration``.
evals_result: dict
evals_result :
This dictionary stores the evaluation results of all the items in watchlist.

Example: with a watchlist containing
@ -191,7 +190,7 @@ def train(
{'train': {'logloss': ['0.48253', '0.35953']},
'eval': {'logloss': ['0.480385', '0.357756']}}

verbose_eval : bool or int
verbose_eval :
Requires at least one item in **evals**.
If **verbose_eval** is True then the evaluation metric on the validation set is
printed at each boosting stage.
@ -200,9 +199,9 @@ def train(
/ the boosting stage found by using **early_stopping_rounds** is also printed.
Example: with ``verbose_eval=4`` and at least one item in **evals**, an evaluation metric
is printed every 4 boosting stages, instead of every boosting stage.
xgb_model : file name of stored xgb model or 'Booster' instance
xgb_model :
Xgb model to be loaded before training (allows training continuation).
callbacks : list of callback functions
callbacks :
List of callback functions that are applied at end of each iteration.
It is possible to use predefined callbacks by using
:ref:`Callback API <callback_api>`.
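
Putting the annotated signature together, a hedged end-to-end sketch of calling ``train`` with evaluation, early stopping, and a result dictionary (synthetic data; all numbers are arbitrary):

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    rng = np.random.RandomState(0)
    X, y = rng.rand(200, 4), rng.randint(0, 2, 200)
    dtrain = xgb.DMatrix(X[:150], label=y[:150])
    dvalid = xgb.DMatrix(X[150:], label=y[150:])

    evals_result: xgb.callback.TrainingCallback.EvalsLog = {}
    booster = xgb.train(
        {"objective": "binary:logistic", "eval_metric": "logloss"},
        dtrain,
        num_boost_round=20,
        evals=[(dtrain, "train"), (dvalid, "eval")],
        early_stopping_rounds=5,
        evals_result=evals_result,
        verbose_eval=False,
    )
    print(booster.best_iteration, evals_result["eval"]["logloss"][-1])
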

@ -9,7 +9,9 @@ rng = np.random.RandomState(1994)


class TestInteractionConstraints:
def run_interaction_constraints(self, tree_method, feature_names=None, interaction_constraints='[[0, 1]]'):
def run_interaction_constraints(
self, tree_method, feature_names=None, interaction_constraints='[[0, 1]]'
):
x1 = np.random.normal(loc=1.0, scale=1.0, size=1000)
x2 = np.random.normal(loc=1.0, scale=1.0, size=1000)
x3 = np.random.choice([1, 2, 3], size=1000, replace=True)
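
For context on the test reformatted above, a hedged sketch of the ``interaction_constraints`` parameter it exercises: the JSON string ``'[[0, 1]]'`` permits features 0 and 1 to interact with each other but not with other features (synthetic data; values arbitrary):

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    rng = np.random.RandomState(1994)
    X = rng.normal(size=(1000, 3))
    y = X[:, 0] * X[:, 1] + X[:, 2]

    booster = xgb.train(
        {"tree_method": "hist", "interaction_constraints": "[[0, 1]]"},
        xgb.DMatrix(X, label=y),
        num_boost_round=10,
    )
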