Typehint for subset of core API. (#7348)

parent 45aef75cca
commit c6769488b3
--- python-package/xgboost/callback.py
+++ python-package/xgboost/callback.py
@@ -6,17 +6,18 @@ from abc import ABC
 import collections
 import os
 import pickle
-from typing import Callable, List, Optional, Union, Dict, Tuple
+from typing import Callable, List, Optional, Union, Dict, Tuple, TypeVar, cast
+from typing import Sequence
 import numpy
 
 from . import rabit
-from .core import Booster, XGBoostError
+from .core import Booster, DMatrix, XGBoostError
 from .compat import STRING_TYPES
 
 
-# The new implementation of callback functions.
-# Breaking:
-# - reset learning rate no longer accepts total boosting rounds
+_Score = Union[float, Tuple[float, float]]
+_ScoreList = Union[List[float], List[Tuple[float, float]]]
+
 
 # pylint: disable=unused-argument
 class TrainingCallback(ABC):
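The two new aliases capture the shape of recorded evaluation results: plain training logs one float per round, while cross validation logs a (mean, std) pair. A minimal sketch of values matching the aliases (variable names here are illustrative only):

    from typing import List, Tuple, Union

    _Score = Union[float, Tuple[float, float]]
    _ScoreList = Union[List[float], List[Tuple[float, float]]]

    train_log: _ScoreList = [0.51, 0.42, 0.37]         # one float per round
    cv_log: _ScoreList = [(0.51, 0.02), (0.42, 0.01)]  # (mean, std) per round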
@@ -26,9 +27,9 @@ class TrainingCallback(ABC):
 
     '''
 
-    EvalsLog = Dict[str, Dict[str, Union[List[float], List[Tuple[float, float]]]]]
+    EvalsLog = Dict[str, Dict[str, _ScoreList]]
 
-    def __init__(self):
+    def __init__(self) -> None:
         pass
 
     def before_training(self, model):
@@ -48,18 +49,18 @@ class TrainingCallback(ABC):
         return False
 
 
-def _aggcv(rlist):
-    # pylint: disable=invalid-name
+def _aggcv(rlist: List[str]) -> List[Tuple[str, float, float]]:
+    # pylint: disable=invalid-name, too-many-locals
     """Aggregate cross-validation results.
 
    """
-    cvmap = {}
+    cvmap: Dict[Tuple[int, str], List[float]] = {}
     idx = rlist[0].split()[0]
     for line in rlist:
-        arr = line.split()
+        arr: List[str] = line.split()
         assert idx == arr[0]
         for metric_idx, it in enumerate(arr[1:]):
-            if not isinstance(it, STRING_TYPES):
+            if not isinstance(it, str):
                 it = it.decode()
             k, v = it.split(':')
             if (metric_idx, k) not in cvmap:
@@ -67,16 +68,20 @@ def _aggcv(rlist):
            cvmap[(metric_idx, k)].append(float(v))
     msg = idx
     results = []
-    for (metric_idx, k), v in sorted(cvmap.items(), key=lambda x: x[0][0]):
-        v = numpy.array(v)
+    for (_, name), s in sorted(cvmap.items(), key=lambda x: x[0][0]):
+        as_arr = numpy.array(s)
         if not isinstance(msg, STRING_TYPES):
             msg = msg.decode()
-        mean, std = numpy.mean(v), numpy.std(v)
-        results.extend([(k, mean, std)])
+        mean, std = numpy.mean(as_arr), numpy.std(as_arr)
+        results.extend([(name, mean, std)])
     return results
 
 
-def _allreduce_metric(score):
+# allreduce type
+_ART = TypeVar("_ART")
+
+
+def _allreduce_metric(score: _ART) -> _ART:
     '''Helper function for computing customized metric in distributed
     environment.  Not strictly correct as many functions don't use mean value
     as final result.
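`_allreduce_metric` is made generic over a TypeVar so that, as far as the type checker is concerned, a float argument yields a float and a tuple argument yields a tuple. A minimal, self-contained sketch of the same pattern (not xgboost code):

    from typing import TypeVar

    T = TypeVar("T")

    def through(value: T) -> T:
        # mypy infers the return type from the argument type.
        return value

    f = through(0.5)         # inferred as float
    p = through((0.5, 0.1))  # inferred as Tuple[float, float]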
@@ -89,13 +94,13 @@ def _allreduce_metric(score):
     if isinstance(score, tuple):  # has mean and stdv
         raise ValueError(
             'xgboost.cv function should not be used in distributed environment.')
-    score = numpy.array([score])
-    score = rabit.allreduce(score, rabit.Op.SUM) / world
-    return score[0]
+    arr = numpy.array([score])
+    arr = rabit.allreduce(arr, rabit.Op.SUM) / world
+    return arr[0]
 
 
 class CallbackContainer:
-    '''A special callback for invoking a list of other callbacks.
+    '''A special internal callback for invoking a list of other callbacks.
 
     .. versionadded:: 1.3.0
 
@@ -105,7 +110,7 @@ class CallbackContainer:
 
     def __init__(
         self,
-        callbacks: List[TrainingCallback],
+        callbacks: Sequence[TrainingCallback],
         metric: Callable = None,
         output_margin: bool = True,
         is_cv: bool = False
@@ -146,33 +151,50 @@ class CallbackContainer:
         assert isinstance(model, Booster), msg
         return model
 
-    def before_iteration(self, model, epoch, dtrain, evals) -> bool:
+    def before_iteration(
+        self, model, epoch: int, dtrain: DMatrix, evals: List[Tuple[DMatrix, str]]
+    ) -> bool:
         '''Function called before training iteration.'''
         return any(c.before_iteration(model, epoch, self.history)
                    for c in self.callbacks)
 
-    def _update_history(self, score, epoch):
+    def _update_history(
+        self,
+        score: Union[List[Tuple[str, float]], List[Tuple[str, float, float]]],
+        epoch: int
+    ) -> None:
         for d in score:
-            name, s = d[0], float(d[1])
+            name: str = d[0]
+            s: float = d[1]
             if self.is_cv:
-                std = float(d[2])
-                s = (s, std)
+                std = float(cast(Tuple[str, float, float], d)[2])
+                x: _Score = (s, std)
+            else:
+                x = s
             splited_names = name.split('-')
             data_name = splited_names[0]
             metric_name = '-'.join(splited_names[1:])
-            s = _allreduce_metric(s)
-            if data_name in self.history:
-                data_history = self.history[data_name]
-                if metric_name in data_history:
-                    data_history[metric_name].append(s)
-                else:
-                    data_history[metric_name] = [s]
-            else:
+            x = _allreduce_metric(x)
+            if data_name not in self.history:
                 self.history[data_name] = collections.OrderedDict()
-                self.history[data_name][metric_name] = [s]
-        return False
+            data_history = self.history[data_name]
+            if metric_name not in data_history:
+                data_history[metric_name] = cast(_ScoreList, [])
+            metric_history = data_history[metric_name]
+            if self.is_cv:
+                cast(List[Tuple[float, float]], metric_history).append(
+                    cast(Tuple[float, float], x)
+                )
+            else:
+                cast(List[float], metric_history).append(cast(float, x))
 
-    def after_iteration(self, model, epoch, dtrain, evals) -> bool:
+    def after_iteration(
+        self,
+        model,
+        epoch: int,
+        dtrain: DMatrix,
+        evals: Optional[List[Tuple[DMatrix, str]]],
+    ) -> bool:
         '''Function called after training iteration.'''
         if self.is_cv:
             scores = model.eval(epoch, self.metric, self._output_margin)
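The rewritten `_update_history` leans on `typing.cast` to narrow the `_Score`/`_ScoreList` unions. `cast` is purely an annotation-level assertion with no runtime effect, as this minimal sketch (independent of the code above) shows:

    from typing import List, Tuple, Union, cast

    Score = Union[float, Tuple[float, float]]

    history: List[Tuple[float, float]] = []
    x: Score = (0.42, 0.03)
    # cast() returns its argument unchanged; it only tells the checker
    # which member of the Union is actually held here.
    history.append(cast(Tuple[float, float], x))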
@@ -183,18 +205,20 @@ class CallbackContainer:
         evals = [] if evals is None else evals
         for _, name in evals:
             assert name.find('-') == -1, 'Dataset name should not contain `-`'
-        score = model.eval_set(evals, epoch, self.metric, self._output_margin)
-        score = score.split()[1:]  # into datasets
+        score: str = model.eval_set(evals, epoch, self.metric, self._output_margin)
+        splited = score.split()[1:]  # into datasets
         # split up `test-error:0.1234`
-        score = [tuple(s.split(':')) for s in score]
-        self._update_history(score, epoch)
+        metric_score_str = [tuple(s.split(':')) for s in splited]
+        # convert to float
+        metric_score = [(n, float(s)) for n, s in metric_score_str]
+        self._update_history(metric_score, epoch)
         ret = any(c.after_iteration(model, epoch, self.history)
                   for c in self.callbacks)
         return ret
 
 
 class LearningRateScheduler(TrainingCallback):
-    '''Callback function for scheduling learning rate.
+    """Callback function for scheduling learning rate.
 
     .. versionadded:: 1.3.0
 
@@ -207,18 +231,24 @@ class LearningRateScheduler(TrainingCallback):
     should be a sequence like list or tuple with the same size of boosting
     rounds.
 
-    '''
-    def __init__(self, learning_rates) -> None:
-        assert callable(learning_rates) or \
-            isinstance(learning_rates, collections.abc.Sequence)
+    """
+
+    def __init__(
+        self, learning_rates: Union[Callable[[int], float], Sequence[float]]
+    ) -> None:
+        assert callable(learning_rates) or isinstance(
+            learning_rates, collections.abc.Sequence
+        )
         if callable(learning_rates):
             self.learning_rates = learning_rates
         else:
-            self.learning_rates = lambda epoch: learning_rates[epoch]
+            self.learning_rates = lambda epoch: cast(Sequence, learning_rates)[epoch]
         super().__init__()
 
-    def after_iteration(self, model, epoch, evals_log) -> bool:
-        model.set_param('learning_rate', self.learning_rates(epoch))
+    def after_iteration(
+        self, model, epoch: int, evals_log: TrainingCallback.EvalsLog
+    ) -> bool:
+        model.set_param("learning_rate", self.learning_rates(epoch))
         return False
 
 
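Both accepted forms of `learning_rates` are now spelled out in the signature. A hedged usage sketch (assuming the public `xgboost.callback` module path):

    import xgboost as xgb

    # A callable mapping the epoch to a learning rate ...
    decaying = xgb.callback.LearningRateScheduler(lambda epoch: 0.3 * 0.99 ** epoch)

    # ... or a sequence providing one rate per boosting round.
    scheduled = xgb.callback.LearningRateScheduler([0.3, 0.2, 0.1])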
@@ -230,17 +260,17 @@ class EarlyStopping(TrainingCallback):
 
     Parameters
     ----------
-    rounds
+    rounds :
         Early stopping rounds.
-    metric_name
+    metric_name :
         Name of metric that is used for early stopping.
-    data_name
+    data_name :
         Name of dataset that is used for early stopping.
-    maximize
+    maximize :
         Whether to maximize evaluation metric.  None means auto (discouraged).
-    save_best
+    save_best :
         Whether training should return the best model or the last model.
-    min_delta
+    min_delta :
         Minimum absolute change in score to be qualified as an improvement.
 
         .. versionadded:: 1.5.0
@@ -279,8 +309,6 @@ class EarlyStopping(TrainingCallback):
         if self._min_delta < 0:
             raise ValueError("min_delta must be greater or equal to 0.")
 
-        self.improve_op = None
-
         self.current_rounds: int = 0
         self.best_scores: dict = {}
         self.starting_round: int = 0
@@ -290,16 +318,18 @@ class EarlyStopping(TrainingCallback):
         self.starting_round = model.num_boosted_rounds()
         return model
 
-    def _update_rounds(self, score, name, metric, model, epoch) -> bool:
-        def get_s(x):
+    def _update_rounds(
+        self, score: _Score, name: str, metric: str, model, epoch: int
+    ) -> bool:
+        def get_s(x: _Score) -> float:
             """get score if it's cross validation history."""
             return x[0] if isinstance(x, tuple) else x
 
-        def maximize(new, best):
+        def maximize(new: _Score, best: _Score) -> bool:
             """New score should be greater than the old one."""
             return numpy.greater(get_s(new) - self._min_delta, get_s(best))
 
-        def minimize(new, best):
+        def minimize(new: _Score, best: _Score) -> bool:
             """New score should be smaller than the old one."""
             return numpy.greater(get_s(best) - self._min_delta, get_s(new))
 
@@ -314,25 +344,25 @@ class EarlyStopping(TrainingCallback):
                 self.maximize = False
 
         if self.maximize:
-            self.improve_op = maximize
+            improve_op = maximize
         else:
-            self.improve_op = minimize
+            improve_op = minimize
 
-        assert self.improve_op
+        assert improve_op
 
         if not self.stopping_history:  # First round
             self.current_rounds = 0
             self.stopping_history[name] = {}
-            self.stopping_history[name][metric] = [score]
+            self.stopping_history[name][metric] = cast(_ScoreList, [score])
             self.best_scores[name] = {}
             self.best_scores[name][metric] = [score]
             model.set_attr(best_score=str(score), best_iteration=str(epoch))
-        elif not self.improve_op(score, self.best_scores[name][metric][-1]):
+        elif not improve_op(score, self.best_scores[name][metric][-1]):
             # Not improved
-            self.stopping_history[name][metric].append(score)
+            self.stopping_history[name][metric].append(score)  # type: ignore
             self.current_rounds += 1
         else:  # Improved
-            self.stopping_history[name][metric].append(score)
+            self.stopping_history[name][metric].append(score)  # type: ignore
             self.best_scores[name][metric].append(score)
             record = self.stopping_history[name][metric][-1]
             model.set_attr(best_score=str(record), best_iteration=str(epoch))
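Replacing the `self.improve_op` attribute (previously initialised to `None`) with a local name is a common mypy-friendly refactor: a local bound on every branch has a precise function type, whereas an attribute starting as `None` is `Optional` and needs a check at each call site. A minimal sketch of the idea:

    from typing import Callable

    def pick_comparator(maximize: bool) -> Callable[[float, float], bool]:
        # Bound exactly once on each path, so the inferred type is
        # a plain callable, never Optional.
        if maximize:
            improve_op = lambda new, best: new > best
        else:
            improve_op = lambda new, best: new < best
        return improve_op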
@@ -390,16 +420,16 @@ class EvaluationMonitor(TrainingCallback):
     Parameters
     ----------
 
-    metric : callable
+    metric :
         Extra user defined metric.
-    rank : int
+    rank :
         Which worker should be used for printing the result.
-    period : int
+    period :
         How many epoches between printing.
-    show_stdv : bool
+    show_stdv :
         Used in cv to show standard deviation.  Users should not specify it.
     '''
-    def __init__(self, rank=0, period=1, show_stdv=False) -> None:
+    def __init__(self, rank: int = 0, period: int = 1, show_stdv: bool = False) -> None:
         self.printer_rank = rank
         self.show_stdv = show_stdv
         self.period = period
@@ -457,22 +487,27 @@ class TrainingCheckPoint(TrainingCallback):
     Parameters
     ----------
 
-    directory : os.PathLike
+    directory :
         Output model directory.
-    name : str
+    name :
         pattern of output model file.  Models will be saved as name_0.json, name_1.json,
         name_2.json ....
-    as_pickle : boolean
+    as_pickle :
         When set to Ture, all training parameters will be saved in pickle format, instead
         of saving only the model.
-    iterations : int
+    iterations :
         Interval of checkpointing.  Checkpointing is slow so setting a larger number can
         reduce performance hit.
 
     '''
-    def __init__(self, directory: os.PathLike, name: str = 'model',
-                 as_pickle=False, iterations: int = 100):
-        self._path = directory
+    def __init__(
+        self,
+        directory: Union[str, os.PathLike],
+        name: str = 'model',
+        as_pickle: bool = False,
+        iterations: int = 100
+    ) -> None:
+        self._path = os.fspath(directory)
         self._name = name
         self._as_pickle = as_pickle
         self._iterations = iterations
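`TrainingCheckPoint` now accepts both `str` and `os.PathLike` and normalises the argument once with `os.fspath`, which returns the underlying path string. A minimal sketch of the pattern:

    import os
    import pathlib
    from typing import Union

    def normalize(directory: Union[str, os.PathLike]) -> str:
        # os.fspath accepts either form and yields a plain str for str paths.
        return os.fspath(directory)

    assert normalize("models") == "models"
    assert normalize(pathlib.Path("models")) == "models"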
--- python-package/xgboost/core.py
+++ python-package/xgboost/core.py
@@ -6,7 +6,7 @@
 from collections.abc import Mapping
 from typing import List, Optional, Any, Union, Dict, TypeVar
 # pylint: enable=no-name-in-module,import-error
-from typing import Callable, Tuple, cast
+from typing import Callable, Tuple, cast, Sequence
 import ctypes
 import os
 import re
@@ -31,20 +31,6 @@ class XGBoostError(ValueError):
     """Error thrown by xgboost trainer."""
 
 
-class EarlyStopException(Exception):
-    """Exception to signal early stopping.
-
-    Parameters
-    ----------
-    best_iteration : int
-        The best iteration stopped.
-    """
-
-    def __init__(self, best_iteration):
-        super().__init__()
-        self.best_iteration = best_iteration
-
-
 def from_pystr_to_cstr(data: Union[str, List[str]]):
     """Convert a Python str or list of Python str to C pointer
 
@@ -132,18 +118,19 @@ def _log_callback(msg: bytes) -> None:
     print(py_str(msg))
 
 
-def _get_log_callback_func():
+def _get_log_callback_func() -> Callable:
     """Wrap log_callback() method in ctypes callback type"""
     # pylint: disable=invalid-name
     CALLBACK = ctypes.CFUNCTYPE(None, ctypes.c_char_p)
     return CALLBACK(_log_callback)
 
 
-def _load_lib():
+def _load_lib() -> ctypes.CDLL:
     """Load xgboost Library."""
     lib_paths = find_lib_path()
     if not lib_paths:
-        return None
+        # This happens only when building document.
+        return None  # type: ignore
     try:
         pathBackup = os.environ['PATH'].split(os.pathsep)
     except KeyError:
@@ -190,7 +177,7 @@ Error message(s): {os_error_list}
 _LIB = _load_lib()
 
 
-def _check_call(ret):
+def _check_call(ret: int) -> None:
     """Check the return value of C API call
 
     This function will raise exception when error occurs.
@@ -234,7 +221,7 @@ def _cuda_array_interface(data) -> bytes:
     return interface_str
 
 
-def ctypes2numpy(cptr, length, dtype):
+def ctypes2numpy(cptr, length, dtype) -> np.ndarray:
     """Convert a ctypes pointer array to a numpy array."""
     ctype = _numpy2ctypes_type(dtype)
     if not isinstance(cptr, ctypes.POINTER(ctype)):
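`ctypes2numpy` now advertises an `np.ndarray` return type. For reference, a hedged sketch of copying a ctypes buffer into numpy using only documented APIs; this is not the library's internal implementation:

    import ctypes
    import numpy as np

    buf = (ctypes.c_float * 4)(1.0, 2.0, 3.0, 4.0)
    ptr = ctypes.cast(buf, ctypes.POINTER(ctypes.c_float))
    # View the C memory as a numpy array of the given shape, then copy
    # so the result owns its data.
    arr: np.ndarray = np.ctypeslib.as_array(ptr, shape=(4,)).copy()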
@@ -271,7 +258,7 @@ def ctypes2cupy(cptr, length, dtype):
     return arr
 
 
-def ctypes2buffer(cptr, length):
+def ctypes2buffer(cptr, length) -> bytearray:
     """Convert ctypes pointer to buffer type."""
     if not isinstance(cptr, ctypes.POINTER(ctypes.c_char)):
         raise RuntimeError('expected char pointer')
@@ -428,7 +415,7 @@ class DataIter:  # pylint: disable=too-many-instance-attributes
     Parameters
     ----------
 
-    data_handle:
+    input_data:
         A function with same data fields like `data`, `label` with
         `xgboost.DMatrix`.
 
@@ -627,7 +614,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         if feature_types is not None:
             self.feature_types = feature_types
 
-    def _init_from_iter(self, iterator: DataIter, enable_categorical: bool):
+    def _init_from_iter(self, iterator: DataIter, enable_categorical: bool) -> None:
         it = iterator
         args = {
             "missing": self.missing,
@@ -654,7 +641,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         _check_call(ret)
         self.handle = handle
 
-    def __del__(self):
+    def __del__(self) -> None:
         if hasattr(self, "handle") and self.handle:
             _check_call(_LIB.XGDMatrixFree(self.handle))
             self.handle = None
@@ -699,7 +686,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         dispatch_meta_backend(matrix=self, data=feature_weights,
                               name='feature_weights')
 
-    def get_float_info(self, field):
+    def get_float_info(self, field: str) -> np.ndarray:
         """Get float property from the DMatrix.
 
         Parameters
@@ -720,7 +707,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
                                                ctypes.byref(ret)))
         return ctypes2numpy(ret, length.value, np.float32)
 
-    def get_uint_info(self, field):
+    def get_uint_info(self, field: str) -> np.ndarray:
         """Get unsigned integer property from the DMatrix.
 
         Parameters
@@ -741,7 +728,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
                                               ctypes.byref(ret)))
         return ctypes2numpy(ret, length.value, np.uint32)
 
-    def set_float_info(self, field, data):
+    def set_float_info(self, field: str, data) -> None:
         """Set float type property into the DMatrix.
 
         Parameters
@@ -755,7 +742,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         from .data import dispatch_meta_backend
         dispatch_meta_backend(self, data, field, 'float')
 
-    def set_float_info_npy2d(self, field, data):
+    def set_float_info_npy2d(self, field: str, data) -> None:
         """Set float type property into the DMatrix
            for numpy 2d array input
 
@@ -770,7 +757,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         from .data import dispatch_meta_backend
         dispatch_meta_backend(self, data, field, 'float')
 
-    def set_uint_info(self, field, data):
+    def set_uint_info(self, field: str, data) -> None:
         """Set uint type property into the DMatrix.
 
         Parameters
@@ -784,7 +771,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         from .data import dispatch_meta_backend
         dispatch_meta_backend(self, data, field, 'uint32')
 
-    def save_binary(self, fname, silent=True):
+    def save_binary(self, fname, silent=True) -> None:
         """Save DMatrix to an XGBoost buffer.  Saved binary can be later loaded
         by providing the path to :py:func:`xgboost.DMatrix` as input.
 
@@ -800,7 +787,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
                                              c_str(fname),
                                              ctypes.c_int(silent)))
 
-    def set_label(self, label):
+    def set_label(self, label) -> None:
         """Set label of dmatrix
 
         Parameters
@@ -811,7 +798,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         from .data import dispatch_meta_backend
         dispatch_meta_backend(self, label, 'label', 'float')
 
-    def set_weight(self, weight):
+    def set_weight(self, weight) -> None:
         """Set weight of each instance.
 
         Parameters
@@ -830,7 +817,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         from .data import dispatch_meta_backend
         dispatch_meta_backend(self, weight, 'weight', 'float')
 
-    def set_base_margin(self, margin):
+    def set_base_margin(self, margin) -> None:
         """Set base margin of booster to start from.
 
         This can be used to specify a prediction value of existing model to be
@@ -847,7 +834,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         from .data import dispatch_meta_backend
         dispatch_meta_backend(self, margin, 'base_margin', 'float')
 
-    def set_group(self, group):
+    def set_group(self, group) -> None:
         """Set group size of DMatrix (used for ranking).
 
         Parameters
@@ -858,7 +845,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         from .data import dispatch_meta_backend
         dispatch_meta_backend(self, group, 'group', 'uint32')
 
-    def get_label(self):
+    def get_label(self) -> np.ndarray:
         """Get the label of the DMatrix.
 
         Returns
@@ -867,7 +854,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         """
         return self.get_float_info('label')
 
-    def get_weight(self):
+    def get_weight(self) -> np.ndarray:
         """Get the weight of the DMatrix.
 
         Returns
@@ -876,7 +863,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         """
         return self.get_float_info('weight')
 
-    def get_base_margin(self):
+    def get_base_margin(self) -> np.ndarray:
         """Get the base margin of the DMatrix.
 
         Returns
@@ -885,7 +872,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         """
         return self.get_float_info('base_margin')
 
-    def num_row(self):
+    def num_row(self) -> int:
         """Get the number of rows in the DMatrix.
 
         Returns
@@ -897,7 +884,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
                                          ctypes.byref(ret)))
         return ret.value
 
-    def num_col(self):
+    def num_col(self) -> int:
         """Get the number of columns (features) in the DMatrix.
 
         Returns
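With the return annotations above, the `DMatrix` accessors become self-describing to both readers and type checkers. A short usage sketch:

    import numpy as np
    import xgboost as xgb

    X = np.random.rand(5, 3)
    y = np.arange(5, dtype=np.float32)
    dtrain = xgb.DMatrix(X, label=y)

    label: np.ndarray = dtrain.get_label()  # annotated -> np.ndarray
    rows: int = dtrain.num_row()            # annotated -> int
    cols: int = dtrain.num_col()            # annotated -> int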
@@ -1191,7 +1178,7 @@ class DeviceQuantileDMatrix(DMatrix):
             enable_categorical=enable_categorical,
         )
 
-    def _init(self, data, enable_categorical, **meta):
+    def _init(self, data, enable_categorical: bool, **meta) -> None:
         from .data import (
             _is_dlpack,
             _transform_dlpack,
@@ -1265,7 +1252,7 @@ def _get_booster_layer_trees(model: "Booster") -> Tuple[int, int]:
     return num_parallel_tree, num_groups
 
 
-class Booster(object):
+class Booster:
     # pylint: disable=too-many-public-methods
     """A Booster of XGBoost.
 
@@ -1273,7 +1260,12 @@ class Booster(object):
     training, prediction and evaluation.
     """
 
-    def __init__(self, params=None, cache=(), model_file=None):
+    def __init__(
+        self,
+        params: Optional[Dict] = None,
+        cache: Optional[Sequence[DMatrix]] = None,
+        model_file: Optional[Union["Booster", bytearray, os.PathLike, str]] = None
+    ) -> None:
         # pylint: disable=invalid-name
         """
         Parameters
@@ -1285,12 +1277,13 @@ class Booster(object):
         model_file : string/os.PathLike/Booster/bytearray
             Path to the model file if it's string or PathLike.
         """
+        cache = cache if cache is not None else []
         for d in cache:
             if not isinstance(d, DMatrix):
                 raise TypeError(f'invalid cache item: {type(d).__name__}', cache)
 
         dmats = c_array(ctypes.c_void_p, [d.handle for d in cache])
-        self.handle = ctypes.c_void_p()
+        self.handle: Optional[ctypes.c_void_p] = ctypes.c_void_p()
         _check_call(_LIB.XGBoosterCreate(dmats, c_bst_ulong(len(cache)),
                                          ctypes.byref(self.handle)))
         for d in cache:
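Swapping the old `cache=()` default for `Optional[Sequence[DMatrix]] = None` plus an explicit `None` check is the usual way to give a parameter a precise element type while keeping a safe default. A minimal sketch of the pattern:

    from typing import Optional, Sequence

    def make(cache: Optional[Sequence[str]] = None) -> int:
        # Normalise None to an empty sequence before iterating.
        cache = cache if cache is not None else []
        return len(cache)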
@@ -1405,12 +1398,12 @@ class Booster(object):
 
         return params
 
-    def __del__(self):
+    def __del__(self) -> None:
         if hasattr(self, 'handle') and self.handle is not None:
             _check_call(_LIB.XGBoosterFree(self.handle))
             self.handle = None
 
-    def __getstate__(self):
+    def __getstate__(self) -> Dict:
         # can't pickle ctypes pointers, put model content in bytearray
         this = self.__dict__.copy()
         handle = this['handle']
@@ -1424,7 +1417,7 @@ class Booster(object):
         this["handle"] = buf
         return this
 
-    def __setstate__(self, state):
+    def __setstate__(self, state: Dict) -> None:
         # reconstruct handle from raw data
         handle = state['handle']
         if handle is not None:
@@ -1440,7 +1433,7 @@ class Booster(object):
         state['handle'] = handle
         self.__dict__.update(state)
 
-    def __getitem__(self, val):
+    def __getitem__(self, val: Union[int, tuple, slice]) -> "Booster":
         if isinstance(val, int):
             val = slice(val, val+1)
         if isinstance(val, tuple):
@@ -1461,13 +1454,14 @@ class Booster(object):
 
         step = val.step if val.step is not None else 1
 
-        start = ctypes.c_int(start)
-        stop = ctypes.c_int(stop)
-        step = ctypes.c_int(step)
+        c_start = ctypes.c_int(start)
+        c_stop = ctypes.c_int(stop)
+        c_step = ctypes.c_int(step)
 
         sliced_handle = ctypes.c_void_p()
-        status = _LIB.XGBoosterSlice(self.handle, start, stop, step,
-                                     ctypes.byref(sliced_handle))
+        status = _LIB.XGBoosterSlice(
+            self.handle, c_start, c_stop, c_step, ctypes.byref(sliced_handle)
+        )
         if status == -2:
             raise IndexError('Layer index out of range')
         _check_call(status)
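mypy assigns each variable a single type, so re-binding `start` (an `int`) to a `ctypes.c_int` would not type-check; the fresh `c_start`/`c_stop`/`c_step` names keep every variable at one type. Sketch:

    import ctypes

    start = 0                      # start: int
    c_start = ctypes.c_int(start)  # a new name for the new type; re-using
                                   # `start` here would change its type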
@@ -1477,7 +1471,7 @@ class Booster(object):
         sliced.handle = sliced_handle
         return sliced
 
-    def save_config(self):
+    def save_config(self) -> str:
         '''Output internal parameter configuration of Booster as a JSON
         string.
 
@@ -1489,10 +1483,11 @@ class Booster(object):
             self.handle,
             ctypes.byref(length),
             ctypes.byref(json_string)))
-        json_string = json_string.value.decode()  # pylint: disable=no-member
-        return json_string
+        assert json_string.value is not None
+        result = json_string.value.decode()  # pylint: disable=no-member
+        return result
 
-    def load_config(self, config):
+    def load_config(self, config: str) -> None:
         '''Load configuration returned by `save_config`.
 
         .. versionadded:: 1.0.0
@@ -1502,14 +1497,14 @@ class Booster(object):
             self.handle,
             c_str(config)))
 
-    def __copy__(self):
+    def __copy__(self) -> "Booster":
         return self.__deepcopy__(None)
 
-    def __deepcopy__(self, _):
+    def __deepcopy__(self, _) -> "Booster":
         '''Return a copy of booster.'''
         return Booster(model_file=self)
 
-    def copy(self):
+    def copy(self) -> "Booster":
         """Copy the booster object.
 
         Returns
@@ -1519,7 +1514,7 @@ class Booster(object):
         """
         return self.__copy__()
 
-    def attr(self, key):
+    def attr(self, key: str) -> Optional[str]:
         """Get attribute string from the Booster.
 
         Parameters
@@ -1540,7 +1535,7 @@ class Booster(object):
             return py_str(ret.value)
         return None
 
-    def attributes(self):
+    def attributes(self) -> Dict[str, str]:
         """Get attributes stored in the Booster as a dictionary.
 
         Returns
@@ -1572,7 +1567,7 @@ class Booster(object):
         _check_call(_LIB.XGBoosterSetAttr(
             self.handle, c_str(key), value))
 
-    def _get_feature_info(self, field: str):
+    def _get_feature_info(self, field: str) -> Optional[List[str]]:
         length = c_bst_ulong()
         sarr = ctypes.POINTER(ctypes.c_char_p)()
         if not hasattr(self, "handle") or self.handle is None:
@@ -1585,22 +1580,6 @@ class Booster(object):
         feature_info = from_cstr_to_pystr(sarr, length)
         return feature_info if feature_info else None
 
-    @property
-    def feature_types(self) -> Optional[List[str]]:
-        """Feature types for this booster.  Can be directly set by input data or by
-        assignment.
-
-        """
-        return self._get_feature_info("feature_type")
-
-    @property
-    def feature_names(self) -> Optional[List[str]]:
-        """Feature names for this booster.  Can be directly set by input data or by
-        assignment.
-
-        """
-        return self._get_feature_info("feature_name")
-
     def _set_feature_info(self, features: Optional[List[str]], field: str) -> None:
         if features is not None:
             assert isinstance(features, list)
@@ -1618,14 +1597,30 @@ class Booster(object):
             )
         )
 
-    @feature_names.setter
-    def feature_names(self, features: Optional[List[str]]) -> None:
-        self._set_feature_info(features, "feature_name")
+    @property
+    def feature_types(self) -> Optional[List[str]]:
+        """Feature types for this booster.  Can be directly set by input data or by
+        assignment.
+
+        """
+        return self._get_feature_info("feature_type")
 
     @feature_types.setter
     def feature_types(self, features: Optional[List[str]]) -> None:
         self._set_feature_info(features, "feature_type")
 
+    @property
+    def feature_names(self) -> Optional[List[str]]:
+        """Feature names for this booster.  Can be directly set by input data or by
+        assignment.
+
+        """
+        return self._get_feature_info("feature_name")
+
+    @feature_names.setter
+    def feature_names(self, features: Optional[List[str]]) -> None:
+        self._set_feature_info(features, "feature_name")
+
     def set_param(self, params, value=None):
         """Set parameters into the Booster.
 
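The move above is mechanical: `@feature_types.setter` can only be written after the `feature_types` property object exists in the class body, so each getter is placed directly before its setter. A minimal sketch of the required ordering:

    class Config:
        def __init__(self) -> None:
            self._name = "default"

        @property
        def name(self) -> str:  # the property must be defined first ...
            return self._name

        @name.setter
        def name(self, value: str) -> None:  # ... so `.setter` exists here
            self._name = value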
@@ -1645,7 +1640,9 @@ class Booster(object):
             _check_call(_LIB.XGBoosterSetParam(self.handle, c_str(key),
                                                c_str(str(val))))
 
-    def update(self, dtrain, iteration, fobj=None):
+    def update(
+        self, dtrain: DMatrix, iteration: int, fobj: Optional[Objective] = None
+    ) -> None:
         """Update for one iteration, with objective function calculated
         internally.  This function should not be called directly by users.
 
@@ -1672,18 +1669,18 @@ class Booster(object):
             grad, hess = fobj(pred, dtrain)
             self.boost(dtrain, grad, hess)
 
-    def boost(self, dtrain, grad, hess):
+    def boost(self, dtrain: DMatrix, grad: np.ndarray, hess: np.ndarray) -> None:
         """Boost the booster for one iteration, with customized gradient
         statistics.  Like :py:func:`xgboost.Booster.update`, this
         function should not be called directly by users.
 
         Parameters
         ----------
-        dtrain : DMatrix
+        dtrain :
             The training DMatrix.
-        grad : list
+        grad :
             The first order of gradient.
-        hess : list
+        hess :
             The second order of gradient.
 
         """
@@ -1700,17 +1697,23 @@ class Booster(object):
                                                c_array(ctypes.c_float, hess),
                                                c_bst_ulong(len(grad))))
 
-    def eval_set(self, evals, iteration=0, feval=None, output_margin=True):
+    def eval_set(
+        self,
+        evals: Sequence[Tuple[DMatrix, str]],
+        iteration: int = 0,
+        feval: Optional[Metric] = None,
+        output_margin: bool = True
+    ) -> str:
         # pylint: disable=invalid-name
         """Evaluate a set of data.
 
         Parameters
         ----------
-        evals : list of tuples (DMatrix, string)
+        evals :
             List of items to be evaluated.
-        iteration : int
+        iteration :
             Current iteration.
-        feval : function
+        feval :
             Custom evaluation function.
 
         Returns
@@ -1738,6 +1741,7 @@ class Booster(object):
                 ctypes.byref(msg),
             )
         )
+        assert msg.value is not None
         res = msg.value.decode()  # pylint: disable=no-member
         if feval is not None:
             for dmat, evname in evals:
@@ -1754,18 +1758,18 @@ class Booster(object):
                         res += "\t%s-%s:%f" % (evname, name, val)
         return res
 
-    def eval(self, data, name='eval', iteration=0):
+    def eval(self, data: DMatrix, name: str = 'eval', iteration: int = 0) -> str:
         """Evaluate the model on mat.
 
         Parameters
         ----------
-        data : DMatrix
+        data :
             The dmatrix storing the input.
 
-        name : str, optional
+        name :
             The name of the dataset.
 
-        iteration : int, optional
+        iteration :
             The current iteration number.
 
         Returns
@@ -2101,7 +2105,7 @@ class Booster(object):
                 "Data type:" + str(type(data)) + " not supported by inplace prediction."
             )
 
-    def save_model(self, fname: Union[str, os.PathLike]):
+    def save_model(self, fname: Union[str, os.PathLike]) -> None:
         """Save the model to a file.
 
         The model is saved in an XGBoost internal format which is universal among the
@@ -2124,7 +2128,7 @@ class Booster(object):
         else:
             raise TypeError("fname must be a string or os PathLike")
 
-    def save_raw(self):
+    def save_raw(self) -> bytearray:
         """Save the model to a in memory buffer representation instead of file.
 
         Returns
@@ -2232,18 +2236,23 @@ class Booster(object):
         if need_close:
             fout.close()
 
-    def get_dump(self, fmap='', with_stats=False, dump_format="text"):
+    def get_dump(
+        self,
+        fmap: Union[str, os.PathLike] = "",
+        with_stats: bool = False,
+        dump_format: str = "text"
+    ) -> List[str]:
         """Returns the model dump as a list of strings.  Unlike `save_model`, the
         output format is primarily used for visualization or interpretation,
         hence it's more human readable but cannot be loaded back to XGBoost.
 
         Parameters
         ----------
-        fmap : string or os.PathLike, optional
+        fmap :
             Name of the file containing feature map names.
-        with_stats : bool, optional
+        with_stats :
             Controls whether the split statistics are output.
-        dump_format : string, optional
+        dump_format :
             Format of model dump. Can be 'text', 'json' or 'dot'.
 
         """
@@ -2259,7 +2268,9 @@ class Booster(object):
         res = from_cstr_to_pystr(sarr, length)
         return res
 
-    def get_fscore(self, fmap=''):
+    def get_fscore(
+        self, fmap: Union[str, os.PathLike] = ""
+    ) -> Dict[str, Union[float, List[float]]]:
         """Get feature importance of each feature.
 
         .. note:: Zero-importance features will not be included
@@ -2269,7 +2280,7 @@ class Booster(object):
 
         Parameters
         ----------
-        fmap: str or os.PathLike (optional)
+        fmap :
            The name of feature map file
         """
 
@@ -2299,9 +2310,9 @@ class Booster(object):
 
         Parameters
         ----------
-        fmap: str or os.PathLike (optional)
+        fmap:
            The name of feature map file.
-        importance_type: str, default 'weight'
+        importance_type:
             One of the importance types defined above.
 
         Returns
@@ -2343,7 +2354,8 @@ class Booster(object):
             results[feat] = float(score)
         return results
 
-    def trees_to_dataframe(self, fmap=''):  # pylint: disable=too-many-statements
+    # pylint: disable=too-many-statements
+    def trees_to_dataframe(self, fmap: Union[str, os.PathLike] = '') -> DataFrame:
         """Parse a boosted tree model text dump into a pandas DataFrame structure.
 
         This feature is only defined when the decision tree model is chosen as base
@@ -2370,7 +2382,7 @@ class Booster(object):
         node_ids = []
         fids = []
         splits = []
-        categories = []
+        categories: List[Optional[float]] = []
         y_directs = []
         n_directs = []
         missings = []
@@ -2444,7 +2456,7 @@ class Booster(object):
         # pylint: disable=no-member
         return df.sort(['Tree', 'Node']).reset_index(drop=True)
 
-    def _validate_features(self, data: DMatrix):
+    def _validate_features(self, data: DMatrix) -> None:
         """
         Validate Booster and data's feature_names are identical.
         Set feature_names and feature_types from DMatrix
--- python-package/xgboost/dask.py
+++ python-package/xgboost/dask.py
@@ -17,12 +17,13 @@ https://github.com/dask/dask-xgboost
 """
 import platform
 import logging
+import collections
 from contextlib import contextmanager
 from collections import defaultdict
-from collections.abc import Sequence
 from threading import Thread
 from functools import partial, update_wrapper
 from typing import TYPE_CHECKING, List, Tuple, Callable, Optional, Any, Union, Dict, Set
+from typing import Sequence
 from typing import Awaitable, Generator, TypeVar
 
 import numpy
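Note the import swap: `typing.Sequence` is subscriptable (`Sequence[float]`) on every Python version this module supports, while `collections.abc.Sequence` only supports subscripting from Python 3.9; the runtime `isinstance` checks below are spelled out against the concrete ABC instead. Sketch:

    import collections.abc
    from typing import Sequence

    def head(values: Sequence[float]) -> float:
        # Annotation from typing; runtime check against the ABC.
        assert isinstance(values, collections.abc.Sequence)
        return values[0]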
@@ -524,9 +525,9 @@ class DaskPartitionIter(DataIter):  # pylint: disable=R0902
         self._feature_names = feature_names
         self._feature_types = feature_types
 
-        assert isinstance(self._data, Sequence)
+        assert isinstance(self._data, collections.abc.Sequence)
 
-        types = (Sequence, type(None))
+        types = (collections.abc.Sequence, type(None))
         assert isinstance(self._labels, types)
         assert isinstance(self._weights, types)
         assert isinstance(self._base_margin, types)
@@ -817,7 +818,7 @@ async def _get_rabit_args(n_workers: int, client: "distributed.Client") -> List[
 
 def _get_workers_from_data(
     dtrain: DaskDMatrix,
-    evals: Optional[List[Tuple[DaskDMatrix, str]]]
+    evals: Optional[Sequence[Tuple[DaskDMatrix, str]]]
 ) -> List[str]:
     X_worker_map: Set[str] = set(dtrain.worker_map.keys())
     if evals:
@@ -837,13 +838,13 @@ async def _train_async(
     params: Dict[str, Any],
     dtrain: DaskDMatrix,
     num_boost_round: int,
-    evals: Optional[List[Tuple[DaskDMatrix, str]]],
+    evals: Optional[Sequence[Tuple[DaskDMatrix, str]]],
     obj: Optional[Objective],
     feval: Optional[Metric],
     early_stopping_rounds: Optional[int],
     verbose_eval: Union[int, bool],
     xgb_model: Optional[Booster],
-    callbacks: Optional[List[TrainingCallback]],
+    callbacks: Optional[Sequence[TrainingCallback]],
     custom_metric: Optional[Metric],
 ) -> Optional[TrainReturnT]:
     workers = _get_workers_from_data(dtrain, evals)
@ -951,13 +952,13 @@ def train(  # pylint: disable=unused-argument
     dtrain: DaskDMatrix,
     num_boost_round: int = 10,
     *,
-    evals: Optional[List[Tuple[DaskDMatrix, str]]] = None,
+    evals: Optional[Sequence[Tuple[DaskDMatrix, str]]] = None,
    obj: Optional[Objective] = None,
    feval: Optional[Metric] = None,
    early_stopping_rounds: Optional[int] = None,
    xgb_model: Optional[Booster] = None,
    verbose_eval: Union[int, bool] = True,
-    callbacks: Optional[List[TrainingCallback]] = None,
+    callbacks: Optional[Sequence[TrainingCallback]] = None,
     custom_metric: Optional[Metric] = None,
 ) -> Any:
     """Train XGBoost model.
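With `evals` and `callbacks` widened to `Sequence`, callers may pass tuples without tripping a type checker. A hedged usage sketch (`client`, `X`, and `y` are assumed to exist; this is not code from the patch):

.. code-block:: python

    from xgboost import dask as dxgb

    dtrain = dxgb.DaskDMatrix(client, X, y)
    output = dxgb.train(
        client,
        {"objective": "reg:squarederror"},
        dtrain,
        num_boost_round=10,
        evals=((dtrain, "train"),),  # a tuple now satisfies the annotation
    )
    booster = output["booster"]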
@ -1648,15 +1649,15 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
         y: _DaskCollection,
         sample_weight: Optional[_DaskCollection],
         base_margin: Optional[_DaskCollection],
-        eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]],
+        eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]],
-        eval_metric: Optional[Union[str, List[str], Metric]],
+        eval_metric: Optional[Union[str, Sequence[str], Metric]],
-        sample_weight_eval_set: Optional[List[_DaskCollection]],
+        sample_weight_eval_set: Optional[Sequence[_DaskCollection]],
-        base_margin_eval_set: Optional[List[_DaskCollection]],
+        base_margin_eval_set: Optional[Sequence[_DaskCollection]],
         early_stopping_rounds: Optional[int],
         verbose: bool,
         xgb_model: Optional[Union[Booster, XGBModel]],
         feature_weights: Optional[_DaskCollection],
-        callbacks: Optional[List[TrainingCallback]],
+        callbacks: Optional[Sequence[TrainingCallback]],
     ) -> _DaskCollection:
         params = self.get_xgb_params()
         dtrain, evals = await _async_wrap_evaluation_matrices(
@ -1714,15 +1715,15 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
         *,
         sample_weight: Optional[_DaskCollection] = None,
         base_margin: Optional[_DaskCollection] = None,
-        eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None,
+        eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
-        eval_metric: Optional[Union[str, List[str], Metric]] = None,
+        eval_metric: Optional[Union[str, Sequence[str], Callable]] = None,
         early_stopping_rounds: Optional[int] = None,
         verbose: bool = True,
         xgb_model: Optional[Union[Booster, XGBModel]] = None,
-        sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
+        sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
-        base_margin_eval_set: Optional[List[_DaskCollection]] = None,
+        base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
         feature_weights: Optional[_DaskCollection] = None,
-        callbacks: Optional[List[TrainingCallback]] = None,
+        callbacks: Optional[Sequence[TrainingCallback]] = None,
     ) -> "DaskXGBRegressor":
         _assert_dask_support()
         args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
@ -1738,15 +1739,15 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
         self, X: _DaskCollection, y: _DaskCollection,
         sample_weight: Optional[_DaskCollection],
         base_margin: Optional[_DaskCollection],
-        eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]],
+        eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]],
-        eval_metric: Optional[Union[str, List[str], Metric]],
+        eval_metric: Optional[Union[str, Sequence[str], Metric]],
-        sample_weight_eval_set: Optional[List[_DaskCollection]],
+        sample_weight_eval_set: Optional[Sequence[_DaskCollection]],
-        base_margin_eval_set: Optional[List[_DaskCollection]],
+        base_margin_eval_set: Optional[Sequence[_DaskCollection]],
         early_stopping_rounds: Optional[int],
         verbose: bool,
         xgb_model: Optional[Union[Booster, XGBModel]],
         feature_weights: Optional[_DaskCollection],
-        callbacks: Optional[List[TrainingCallback]]
+        callbacks: Optional[Sequence[TrainingCallback]]
     ) -> "DaskXGBClassifier":
         params = self.get_xgb_params()
         dtrain, evals = await _async_wrap_evaluation_matrices(
@ -1818,15 +1819,15 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
         *,
         sample_weight: Optional[_DaskCollection] = None,
         base_margin: Optional[_DaskCollection] = None,
-        eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None,
+        eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
-        eval_metric: Optional[Union[str, List[str], Metric]] = None,
+        eval_metric: Optional[Union[str, Sequence[str], Callable]] = None,
         early_stopping_rounds: Optional[int] = None,
         verbose: bool = True,
         xgb_model: Optional[Union[Booster, XGBModel]] = None,
-        sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
+        sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
-        base_margin_eval_set: Optional[List[_DaskCollection]] = None,
+        base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
         feature_weights: Optional[_DaskCollection] = None,
-        callbacks: Optional[List[TrainingCallback]] = None
+        callbacks: Optional[Sequence[TrainingCallback]] = None
     ) -> "DaskXGBClassifier":
         _assert_dask_support()
         args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
@ -1935,17 +1936,17 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
         qid: Optional[_DaskCollection],
         sample_weight: Optional[_DaskCollection],
         base_margin: Optional[_DaskCollection],
-        eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]],
+        eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]],
-        sample_weight_eval_set: Optional[List[_DaskCollection]],
+        sample_weight_eval_set: Optional[Sequence[_DaskCollection]],
-        base_margin_eval_set: Optional[List[_DaskCollection]],
+        base_margin_eval_set: Optional[Sequence[_DaskCollection]],
-        eval_group: Optional[List[_DaskCollection]],
+        eval_group: Optional[Sequence[_DaskCollection]],
-        eval_qid: Optional[List[_DaskCollection]],
+        eval_qid: Optional[Sequence[_DaskCollection]],
-        eval_metric: Optional[Union[str, List[str], Metric]],
+        eval_metric: Optional[Union[str, Sequence[str], Metric]],
         early_stopping_rounds: Optional[int],
         verbose: bool,
         xgb_model: Optional[Union[XGBModel, Booster]],
         feature_weights: Optional[_DaskCollection],
-        callbacks: Optional[List[TrainingCallback]],
+        callbacks: Optional[Sequence[TrainingCallback]],
     ) -> "DaskXGBRanker":
         msg = "Use `qid` instead of `group` on dask interface."
         if not (group is None and eval_group is None):
@ -2010,17 +2011,17 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
         qid: Optional[_DaskCollection] = None,
         sample_weight: Optional[_DaskCollection] = None,
         base_margin: Optional[_DaskCollection] = None,
-        eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None,
+        eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
-        eval_group: Optional[List[_DaskCollection]] = None,
+        eval_group: Optional[Sequence[_DaskCollection]] = None,
-        eval_qid: Optional[List[_DaskCollection]] = None,
+        eval_qid: Optional[Sequence[_DaskCollection]] = None,
-        eval_metric: Optional[Union[str, List[str], Metric]] = None,
+        eval_metric: Optional[Union[str, Sequence[str], Callable]] = None,
         early_stopping_rounds: int = None,
         verbose: bool = False,
         xgb_model: Optional[Union[XGBModel, Booster]] = None,
-        sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
+        sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
-        base_margin_eval_set: Optional[List[_DaskCollection]] = None,
+        base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
         feature_weights: Optional[_DaskCollection] = None,
-        callbacks: Optional[List[TrainingCallback]] = None
+        callbacks: Optional[Sequence[TrainingCallback]] = None
     ) -> "DaskXGBRanker":
         _assert_dask_support()
         args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
@ -2077,15 +2078,15 @@ class DaskXGBRFRegressor(DaskXGBRegressor):
         *,
         sample_weight: Optional[_DaskCollection] = None,
         base_margin: Optional[_DaskCollection] = None,
-        eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None,
+        eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
-        eval_metric: Optional[Union[str, List[str], Metric]] = None,
+        eval_metric: Optional[Union[str, Sequence[str], Callable]] = None,
         early_stopping_rounds: Optional[int] = None,
         verbose: bool = True,
         xgb_model: Optional[Union[Booster, XGBModel]] = None,
-        sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
+        sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
-        base_margin_eval_set: Optional[List[_DaskCollection]] = None,
+        base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
         feature_weights: Optional[_DaskCollection] = None,
-        callbacks: Optional[List[TrainingCallback]] = None
+        callbacks: Optional[Sequence[TrainingCallback]] = None
     ) -> "DaskXGBRFRegressor":
         _assert_dask_support()
         args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
@ -2141,15 +2142,15 @@ class DaskXGBRFClassifier(DaskXGBClassifier):
         *,
         sample_weight: Optional[_DaskCollection] = None,
         base_margin: Optional[_DaskCollection] = None,
-        eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None,
+        eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
-        eval_metric: Optional[Union[str, List[str], Metric]] = None,
+        eval_metric: Optional[Union[str, Sequence[str], Callable]] = None,
         early_stopping_rounds: Optional[int] = None,
         verbose: bool = True,
         xgb_model: Optional[Union[Booster, XGBModel]] = None,
-        sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
+        sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
-        base_margin_eval_set: Optional[List[_DaskCollection]] = None,
+        base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
         feature_weights: Optional[_DaskCollection] = None,
-        callbacks: Optional[List[TrainingCallback]] = None
+        callbacks: Optional[Sequence[TrainingCallback]] = None
     ) -> "DaskXGBRFClassifier":
         _assert_dask_support()
         args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
@ -4,7 +4,8 @@ import copy
 import warnings
 import json
 import os
-from typing import Union, Optional, List, Dict, Callable, Tuple, Any, TypeVar, Type
+from typing import Union, Optional, List, Dict, Callable, Tuple, Any, TypeVar, Type, cast
+from typing import Sequence
 import numpy as np

 from .core import Booster, DMatrix, XGBoostError
@ -36,7 +37,7 @@ class XGBRankerMixIn:  # pylint: disable=too-few-public-methods

 def _check_rf_callback(
     early_stopping_rounds: Optional[int],
-    callbacks: Optional[List[TrainingCallback]],
+    callbacks: Optional[Sequence[TrainingCallback]],
 ) -> None:
     if early_stopping_rounds is not None or callbacks is not None:
         raise NotImplementedError(
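The guard above only widens the annotation; its behavior is unchanged. For context, a minimal sketch of what a caller sees (`X` and `y` are assumed training data, not part of the patch):

.. code-block:: python

    from xgboost import XGBRFClassifier
    from xgboost.callback import EvaluationMonitor

    clf = XGBRFClassifier()
    # random-forest wrappers reject callbacks and early stopping outright:
    # this call raises NotImplementedError inside _check_rf_callback
    clf.fit(X, y, callbacks=[EvaluationMonitor()])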
@ -343,14 +344,14 @@ def _wrap_evaluation_matrices(
     sample_weight: Optional[Any],
     base_margin: Optional[Any],
     feature_weights: Optional[Any],
-    eval_set: Optional[List[Tuple[Any, Any]]],
+    eval_set: Optional[Sequence[Tuple[Any, Any]]],
-    sample_weight_eval_set: Optional[List[Any]],
+    sample_weight_eval_set: Optional[Sequence[Any]],
-    base_margin_eval_set: Optional[List[Any]],
+    base_margin_eval_set: Optional[Sequence[Any]],
-    eval_group: Optional[List[Any]],
+    eval_group: Optional[Sequence[Any]],
-    eval_qid: Optional[List[Any]],
+    eval_qid: Optional[Sequence[Any]],
     create_dmatrix: Callable,
     enable_categorical: bool,
-) -> Tuple[Any, Optional[List[Tuple[Any, str]]]]:
+) -> Tuple[Any, List[Tuple[Any, str]]]:
     """Convert array_like evaluation matrices into DMatrix. Perform validation on the way.

     """
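Dropping `Optional` from the return type is a small but real simplification: when no `eval_set` is given, the function can return an empty list, so no caller needs a None-guard. A toy illustration of the pattern (not the actual helper):

.. code-block:: python

    from typing import List, Optional, Sequence, Tuple

    def wrap(eval_set: Optional[Sequence[Tuple[str, str]]]) -> List[Tuple[str, str]]:
        # [] instead of None: callers can always iterate
        return list(eval_set) if eval_set else []

    for matrix, name in wrap(None):  # loop body simply never runs
        print(matrix, name)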
@ -368,7 +369,7 @@ def _wrap_evaluation_matrices(

     n_validation = 0 if eval_set is None else len(eval_set)

-    def validate_or_none(meta: Optional[List], name: str) -> List:
+    def validate_or_none(meta: Optional[Sequence], name: str) -> Sequence:
         if meta is None:
             return [None] * n_validation
         if len(meta) != n_validation:
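For readers skimming the hunk, the helper's contract is: `None` expands to one placeholder per validation set, and any real sequence must match `eval_set` in length. A standalone re-sketch with `n_validation` passed explicitly (the original closes over it) and the error text paraphrased:

.. code-block:: python

    from typing import Optional, Sequence

    def validate_or_none(meta: Optional[Sequence], name: str, n_validation: int) -> Sequence:
        if meta is None:
            return [None] * n_validation
        if len(meta) != n_validation:
            raise ValueError(f"{name} must have the same length as eval_set")
        return meta

    assert validate_or_none(None, "sample_weight_eval_set", 2) == [None, None]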
@ -464,7 +465,7 @@ class XGBModel(XGBModelBase):
         missing: float = np.nan,
         num_parallel_tree: Optional[int] = None,
         monotone_constraints: Optional[Union[Dict[str, int], str]] = None,
-        interaction_constraints: Optional[Union[str, List[Tuple[str]]]] = None,
+        interaction_constraints: Optional[Union[str, Sequence[Sequence[str]]]] = None,
         importance_type: Optional[str] = None,
         gpu_id: Optional[int] = None,
         validate_parameters: Optional[bool] = None,
@ -715,7 +716,7 @@ class XGBModel(XGBModelBase):
     def _configure_fit(
         self,
         booster: Optional[Union[Booster, "XGBModel", str]],
-        eval_metric: Optional[Union[Callable, str, List[str]]],
+        eval_metric: Optional[Union[Callable, str, Sequence[str]]],
         params: Dict[str, Any],
         early_stopping_rounds: Optional[int],
     ) -> Tuple[
@ -788,10 +789,7 @@ class XGBModel(XGBModelBase):

     def _set_evaluation_result(self, evals_result: TrainingCallback.EvalsLog) -> None:
         if evals_result:
-            for val in evals_result.items():
-                evals_result_key = list(val[1].keys())[0]
-                evals_result[val[0]][evals_result_key] = val[1][evals_result_key]
-            self.evals_result_ = evals_result
+            self.evals_result_ = cast(Dict[str, Dict[str, List[float]]], evals_result)

     @_deprecate_positional_args
     def fit(
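The rewritten `_set_evaluation_result` leans on `typing.cast`, which deserves a note: `cast` is purely a static-analysis hint and returns its argument unchanged at runtime, so the old per-key copy loop could be dropped without changing the stored object. A minimal demonstration:

.. code-block:: python

    from typing import Dict, List, cast

    raw = {"validation_0": {"logloss": [0.6, 0.5]}}
    typed = cast(Dict[str, Dict[str, List[float]]], raw)
    assert typed is raw  # same object; only the type checker sees a difference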
@ -801,15 +799,15 @@ class XGBModel(XGBModelBase):
         *,
         sample_weight: Optional[array_like] = None,
         base_margin: Optional[array_like] = None,
-        eval_set: Optional[List[Tuple[array_like, array_like]]] = None,
+        eval_set: Optional[Sequence[Tuple[array_like, array_like]]] = None,
-        eval_metric: Optional[Union[str, List[str], Metric]] = None,
+        eval_metric: Optional[Union[str, Sequence[str], Metric]] = None,
         early_stopping_rounds: Optional[int] = None,
         verbose: Optional[bool] = True,
         xgb_model: Optional[Union[Booster, str, "XGBModel"]] = None,
-        sample_weight_eval_set: Optional[List[array_like]] = None,
+        sample_weight_eval_set: Optional[Sequence[array_like]] = None,
-        base_margin_eval_set: Optional[List[array_like]] = None,
+        base_margin_eval_set: Optional[Sequence[array_like]] = None,
         feature_weights: Optional[array_like] = None,
-        callbacks: Optional[List[TrainingCallback]] = None
+        callbacks: Optional[Sequence[TrainingCallback]] = None
     ) -> "XGBModel":
         # pylint: disable=invalid-name,attribute-defined-outside-init
         """Fit gradient boosting model.
@ -1031,7 +1029,7 @@ class XGBModel(XGBModelBase):
             Input features matrix.

         iteration_range :
-            See :py:meth:`xgboost.XGBRegressor.predict`.
+            See :py:meth:`predict`.

         ntree_limit :
             Deprecated, use ``iteration_range`` instead.
@ -1055,40 +1053,26 @@ class XGBModel(XGBModelBase):
             iteration_range=iteration_range
         )

-    def evals_result(self) -> TrainingCallback.EvalsLog:
+    def evals_result(self) -> Dict[str, Dict[str, List[float]]]:
         """Return the evaluation results.

-        If **eval_set** is passed to the `fit` function, you can call
-        ``evals_result()`` to get evaluation results for all passed **eval_sets**.
-        When **eval_metric** is also passed to the `fit` function, the
-        **evals_result** will contain the **eval_metrics** passed to the `fit` function.
+        If **eval_set** is passed to the :py:meth:`fit` function, you can call
+        ``evals_result()`` to get evaluation results for all passed **eval_sets**. When
+        **eval_metric** is also passed to the :py:meth:`fit` function, the
+        **evals_result** will contain the **eval_metrics** passed to the :py:meth:`fit`
+        function.

-        Returns
-        -------
-        evals_result : dictionary
-
-        Example
-        -------
-
-        .. code-block:: python
-
-            param_dist = {'objective':'binary:logistic', 'n_estimators':2}
-
-            clf = xgb.XGBModel(**param_dist)
-
-            clf.fit(X_train, y_train,
-                    eval_set=[(X_train, y_train), (X_test, y_test)],
-                    eval_metric='logloss',
-                    verbose=True)
-
-            evals_result = clf.evals_result()
-
-        The variable **evals_result** will contain:
+        The returned evaluation result is a dictionary:

         .. code-block:: python

             {'validation_0': {'logloss': ['0.604835', '0.531479']},
              'validation_1': {'logloss': ['0.41965', '0.17686']}}

+        Returns
+        -------
+        evals_result

         """
         if getattr(self, "evals_result_", None) is not None:
             evals_result = self.evals_result_
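The reworked docstring drops the long inline example; a compact equivalent is worth keeping in mind (the `X_train`/`X_test` splits are assumed, not part of the patch):

.. code-block:: python

    import xgboost as xgb

    clf = xgb.XGBModel(objective="binary:logistic", n_estimators=2)
    clf.fit(
        X_train, y_train,
        eval_set=[(X_train, y_train), (X_test, y_test)],
        eval_metric="logloss",
    )
    # e.g. {'validation_0': {'logloss': [...]}, 'validation_1': {'logloss': [...]}}
    print(clf.evals_result())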
@ -1193,8 +1177,8 @@ class XGBModel(XGBModelBase):
         .. note:: Intercept is defined only for linear learners

             Intercept (bias) is only defined when the linear model is chosen as base
-            learner (`booster=gblinear`). It is not defined for other base learner types, such
-            as tree learners (`booster=gbtree`).
+            learner (`booster=gblinear`). It is not defined for other base learner types,
+            such as tree learners (`booster=gbtree`).

         Returns
         -------
@ -1251,15 +1235,15 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         *,
         sample_weight: Optional[array_like] = None,
         base_margin: Optional[array_like] = None,
-        eval_set: Optional[List[Tuple[array_like, array_like]]] = None,
+        eval_set: Optional[Sequence[Tuple[array_like, array_like]]] = None,
-        eval_metric: Optional[Union[str, List[str], Metric]] = None,
+        eval_metric: Optional[Union[str, Sequence[str], Metric]] = None,
         early_stopping_rounds: Optional[int] = None,
         verbose: Optional[bool] = True,
         xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
-        sample_weight_eval_set: Optional[List[array_like]] = None,
+        sample_weight_eval_set: Optional[Sequence[array_like]] = None,
-        base_margin_eval_set: Optional[List[array_like]] = None,
+        base_margin_eval_set: Optional[Sequence[array_like]] = None,
         feature_weights: Optional[array_like] = None,
-        callbacks: Optional[List[TrainingCallback]] = None
+        callbacks: Optional[Sequence[TrainingCallback]] = None
     ) -> "XGBClassifier":
         # pylint: disable = attribute-defined-outside-init,too-many-statements
         evals_result: TrainingCallback.EvalsLog = {}
@ -1445,51 +1429,6 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
             getattr(self, "n_classes_", None), class_probs, np.vstack
         )

-    def evals_result(self) -> TrainingCallback.EvalsLog:
-        """Return the evaluation results.
-
-        If **eval_set** is passed to the `fit` function, you can call
-        ``evals_result()`` to get evaluation results for all passed **eval_sets**.
-
-        When **eval_metric** is also passed as a parameter, the **evals_result** will
-        contain the **eval_metric** passed to the `fit` function.
-
-        Returns
-        -------
-        evals_result : dictionary
-
-        Example
-        -------
-
-        .. code-block:: python
-
-            param_dist = {
-                'objective':'binary:logistic', 'n_estimators':2, eval_metric="logloss"
-            }
-
-            clf = xgb.XGBClassifier(**param_dist)
-
-            clf.fit(X_train, y_train,
-                    eval_set=[(X_train, y_train), (X_test, y_test)],
-                    verbose=True)
-
-            evals_result = clf.evals_result()
-
-        The variable **evals_result** will contain
-
-        .. code-block:: python
-
-            {'validation_0': {'logloss': ['0.604835', '0.531479']},
-             'validation_1': {'logloss': ['0.41965', '0.17686']}}
-
-        """
-        if self.evals_result_:
-            evals_result = self.evals_result_
-        else:
-            raise XGBoostError('No results.')
-
-        return evals_result
-

 @xgboost_model_doc(
     "scikit-learn API for XGBoost random forest classification.",
@ -1533,15 +1472,15 @@ class XGBRFClassifier(XGBClassifier):
         *,
         sample_weight: Optional[array_like] = None,
         base_margin: Optional[array_like] = None,
-        eval_set: Optional[List[Tuple[array_like, array_like]]] = None,
+        eval_set: Optional[Sequence[Tuple[array_like, array_like]]] = None,
-        eval_metric: Optional[Union[str, List[str], Metric]] = None,
+        eval_metric: Optional[Union[str, Sequence[str], Metric]] = None,
         early_stopping_rounds: Optional[int] = None,
         verbose: Optional[bool] = True,
         xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
-        sample_weight_eval_set: Optional[List[array_like]] = None,
+        sample_weight_eval_set: Optional[Sequence[array_like]] = None,
-        base_margin_eval_set: Optional[List[array_like]] = None,
+        base_margin_eval_set: Optional[Sequence[array_like]] = None,
         feature_weights: Optional[array_like] = None,
-        callbacks: Optional[List[TrainingCallback]] = None
+        callbacks: Optional[Sequence[TrainingCallback]] = None
     ) -> "XGBRFClassifier":
         args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
         _check_rf_callback(early_stopping_rounds, callbacks)
@ -1605,15 +1544,15 @@ class XGBRFRegressor(XGBRegressor):
         *,
         sample_weight: Optional[array_like] = None,
         base_margin: Optional[array_like] = None,
-        eval_set: Optional[List[Tuple[array_like, array_like]]] = None,
+        eval_set: Optional[Sequence[Tuple[array_like, array_like]]] = None,
-        eval_metric: Optional[Union[str, List[str], Metric]] = None,
+        eval_metric: Optional[Union[str, Sequence[str], Metric]] = None,
         early_stopping_rounds: Optional[int] = None,
         verbose: Optional[bool] = True,
         xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
-        sample_weight_eval_set: Optional[List[array_like]] = None,
+        sample_weight_eval_set: Optional[Sequence[array_like]] = None,
-        base_margin_eval_set: Optional[List[array_like]] = None,
+        base_margin_eval_set: Optional[Sequence[array_like]] = None,
         feature_weights: Optional[array_like] = None,
-        callbacks: Optional[List[TrainingCallback]] = None
+        callbacks: Optional[Sequence[TrainingCallback]] = None
     ) -> "XGBRFRegressor":
         args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
         _check_rf_callback(early_stopping_rounds, callbacks)
@ -1682,17 +1621,17 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
         qid: Optional[array_like] = None,
         sample_weight: Optional[array_like] = None,
         base_margin: Optional[array_like] = None,
-        eval_set: Optional[List[Tuple[array_like, array_like]]] = None,
+        eval_set: Optional[Sequence[Tuple[array_like, array_like]]] = None,
-        eval_group: Optional[List[array_like]] = None,
+        eval_group: Optional[Sequence[array_like]] = None,
-        eval_qid: Optional[List[array_like]] = None,
+        eval_qid: Optional[Sequence[array_like]] = None,
-        eval_metric: Optional[Union[str, List[str], Metric]] = None,
+        eval_metric: Optional[Union[str, Sequence[str], Metric]] = None,
         early_stopping_rounds: Optional[int] = None,
         verbose: Optional[bool] = False,
         xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
-        sample_weight_eval_set: Optional[List[array_like]] = None,
+        sample_weight_eval_set: Optional[Sequence[array_like]] = None,
-        base_margin_eval_set: Optional[List[array_like]] = None,
+        base_margin_eval_set: Optional[Sequence[array_like]] = None,
         feature_weights: Optional[array_like] = None,
-        callbacks: Optional[List[TrainingCallback]] = None
+        callbacks: Optional[Sequence[TrainingCallback]] = None
     ) -> "XGBRanker":
         # pylint: disable = attribute-defined-outside-init,arguments-differ
         """Fit gradient boosting ranker
@ -3,18 +3,20 @@
 # pylint: disable=too-many-branches, too-many-statements
 """Training Library containing training routines."""
 import copy
-from typing import Optional, List
+import os
 import warnings
+from typing import Optional, Dict, Any, Union, Tuple, cast, Sequence

 import numpy as np
-from .core import Booster, XGBoostError, _get_booster_layer_trees
+from .core import Booster, DMatrix, XGBoostError, _get_booster_layer_trees
-from .core import _deprecate_positional_args
-from .core import Objective, Metric
+from .core import Metric, Objective
 from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
 from . import callback


-def _assert_new_callback(callbacks: Optional[List[callback.TrainingCallback]]) -> None:
+def _assert_new_callback(
+    callbacks: Optional[Sequence[callback.TrainingCallback]]
+) -> None:
     is_new_callback: bool = not callbacks or all(
         isinstance(c, callback.TrainingCallback) for c in callbacks
     )
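The reformatted guard asserts that every element derives from `callback.TrainingCallback`, which is exactly what the built-in callbacks do. A quick sanity check one can run (assumes only xgboost itself):

.. code-block:: python

    from xgboost import callback

    cbs = (callback.EarlyStopping(rounds=3), callback.EvaluationMonitor())
    assert all(isinstance(c, callback.TrainingCallback) for c in cbs)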
@ -44,24 +46,24 @@ def _configure_custom_metric(


 def _train_internal(
-    params,
+    params: Dict[str, Any],
-    dtrain,
+    dtrain: DMatrix,
-    num_boost_round=10,
+    num_boost_round: int = 10,
-    evals=(),
+    evals: Optional[Sequence[Tuple[DMatrix, str]]] = None,
-    obj=None,
+    obj: Optional[Objective] = None,
-    feval=None,
+    feval: Optional[Metric] = None,
-    custom_metric=None,
+    custom_metric: Optional[Metric] = None,
-    xgb_model=None,
+    xgb_model: Optional[Union[str, os.PathLike, Booster, bytearray]] = None,
-    callbacks=None,
+    callbacks: Optional[Sequence[callback.TrainingCallback]] = None,
-    evals_result=None,
+    evals_result: callback.TrainingCallback.EvalsLog = None,
-    maximize=None,
+    maximize: Optional[bool] = None,
-    verbose_eval=None,
+    verbose_eval: Optional[Union[bool, int]] = True,
-    early_stopping_rounds=None,
+    early_stopping_rounds: Optional[int] = None,
-):
+) -> Booster:
     """internal training function"""
-    callbacks = [] if callbacks is None else copy.copy(callbacks)
+    callbacks = [] if callbacks is None else copy.copy(list(callbacks))
     metric_fn = _configure_custom_metric(feval, custom_metric)
-    evals = list(evals)
+    evals = list(evals) if evals else []

     bst = Booster(params, [dtrain] + [d[0] for d in evals])

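Two of the new defaults deserve a gloss: `evals` now defaults to `None` rather than `()`, and `callbacks` is materialized with `copy.copy(list(callbacks))` so the appends further down never mutate the caller's sequence (`list()` alone already makes a shallow copy; the extra `copy.copy` mirrors the previous code). An illustrative check:

.. code-block:: python

    from xgboost import callback

    user_cbs = (callback.EvaluationMonitor(),)  # any Sequence, e.g. a tuple
    cbs = list(user_cbs)                        # private, appendable list
    cbs.append(callback.EarlyStopping(rounds=5))
    assert len(user_cbs) == 1                   # caller's tuple is untouched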
@ -78,7 +80,7 @@ def _train_internal(
         callbacks.append(
             callback.EarlyStopping(rounds=early_stopping_rounds, maximize=maximize)
         )
-    callbacks = callback.CallbackContainer(
+    cb_container = callback.CallbackContainer(
         callbacks,
         metric=metric_fn,
         # For old `feval` parameter, the behavior is unchanged. For the new
@ -87,32 +89,32 @@ def _train_internal(
         output_margin=callable(obj) or metric_fn is feval,
     )

-    bst = callbacks.before_training(bst)
+    bst = cb_container.before_training(bst)

     for i in range(start_iteration, num_boost_round):
-        if callbacks.before_iteration(bst, i, dtrain, evals):
+        if cb_container.before_iteration(bst, i, dtrain, evals):
             break
         bst.update(dtrain, i, obj)
-        if callbacks.after_iteration(bst, i, dtrain, evals):
+        if cb_container.after_iteration(bst, i, dtrain, evals):
             break

-    bst = callbacks.after_training(bst)
+    bst = cb_container.after_training(bst)

     if evals_result is not None:
-        evals_result.update(callbacks.history)
+        evals_result.update(cb_container.history)

     # These should be moved into callback functions `after_training`, but until old
     # callbacks are removed, the train function is the only place for setting the
     # attributes.
     num_parallel_tree, _ = _get_booster_layer_trees(bst)
     if bst.attr('best_score') is not None:
-        bst.best_score = float(bst.attr('best_score'))
+        bst.best_score = float(cast(str, bst.attr('best_score')))
-        bst.best_iteration = int(bst.attr('best_iteration'))
+        bst.best_iteration = int(cast(str, bst.attr('best_iteration')))
         # num_class is handled internally
         bst.set_attr(
             best_ntree_limit=str((bst.best_iteration + 1) * num_parallel_tree)
         )
-        bst.best_ntree_limit = int(bst.attr("best_ntree_limit"))
+        bst.best_ntree_limit = int(cast(str, bst.attr("best_ntree_limit")))
     else:
         # Due to compatibility with version older than 1.4, these attributes are added
         # to Python object even if early stopping is not used.
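The `cast(str, ...)` wrappers exist because `Booster.attr` is annotated as returning `Optional[str]`, and the checker cannot see that the surrounding `is not None` test covers a second lookup of the same attribute. A self-contained sketch of the narrowing problem (the `attr` stand-in is hypothetical):

.. code-block:: python

    from typing import Optional, cast

    def attr(name: str) -> Optional[str]:  # stand-in for Booster.attr
        return {"best_score": "0.25"}.get(name)

    if attr("best_score") is not None:
        # without the cast, mypy flags float(Optional[str])
        best = float(cast(str, attr("best_score")))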
@ -126,35 +128,32 @@ def _train_internal(
     return bst.copy()


-@_deprecate_positional_args
 def train(
-    params,
+    params: Dict[str, Any],
-    dtrain,
+    dtrain: DMatrix,
-    num_boost_round=10,
+    num_boost_round: int = 10,
-    *,
-    evals=(),
+    evals: Optional[Sequence[Tuple[DMatrix, str]]] = None,
     obj: Optional[Objective] = None,
-    feval=None,
+    feval: Optional[Metric] = None,
-    maximize=None,
+    maximize: Optional[bool] = None,
-    early_stopping_rounds=None,
+    early_stopping_rounds: Optional[int] = None,
-    evals_result=None,
+    evals_result: callback.TrainingCallback.EvalsLog = None,
-    verbose_eval=True,
+    verbose_eval: Optional[Union[bool, int]] = True,
-    xgb_model=None,
+    xgb_model: Optional[Union[str, os.PathLike, Booster, bytearray]] = None,
-    callbacks=None,
+    callbacks: Optional[Sequence[callback.TrainingCallback]] = None,
     custom_metric: Optional[Metric] = None,
-):
+) -> Booster:
-    # pylint: disable=too-many-statements,too-many-branches, attribute-defined-outside-init
     """Train a booster with given parameters.

     Parameters
     ----------
-    params : dict
+    params :
         Booster params.
-    dtrain : DMatrix
+    dtrain :
         Data to be trained.
-    num_boost_round: int
+    num_boost_round :
         Number of boosting iterations.
-    evals: list of pairs (DMatrix, string)
+    evals :
         List of validation sets for which metrics will evaluated during training.
         Validation metrics will help us track the performance of the model.
     obj
@ -166,7 +165,7 @@ def train(
         Use `custom_metric` instead.
     maximize : bool
         Whether to maximize feval.
-    early_stopping_rounds: int
+    early_stopping_rounds :
         Activates early stopping. Validation metric needs to improve at least once in
         every **early_stopping_rounds** round(s) to continue training.
         Requires at least one item in **evals**.
@ -178,7 +177,7 @@ def train(
         **params**, the last metric will be used for early stopping.
         If early stopping occurs, the model will have two additional fields:
         ``bst.best_score``, ``bst.best_iteration``.
-    evals_result: dict
+    evals_result :
         This dictionary stores the evaluation results of all the items in watchlist.

         Example: with a watchlist containing
@ -191,7 +190,7 @@ def train(
            {'train': {'logloss': ['0.48253', '0.35953']},
             'eval': {'logloss': ['0.480385', '0.357756']}}

-    verbose_eval : bool or int
+    verbose_eval :
         Requires at least one item in **evals**.
         If **verbose_eval** is True then the evaluation metric on the validation set is
         printed at each boosting stage.
@ -200,9 +199,9 @@ def train(
         / the boosting stage found by using **early_stopping_rounds** is also printed.
         Example: with ``verbose_eval=4`` and at least one item in **evals**, an evaluation metric
         is printed every 4 boosting stages, instead of every boosting stage.
-    xgb_model : file name of stored xgb model or 'Booster' instance
+    xgb_model :
         Xgb model to be loaded before training (allows training continuation).
-    callbacks : list of callback functions
+    callbacks :
         List of callback functions that are applied at end of each iteration.
         It is possible to use predefined callbacks by using
         :ref:`Callback API <callback_api>`.
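The tightened docstring reads better with a concrete call in mind; a hedged sketch (`params`, `dtrain`, and `dval` are assumed to exist):

.. code-block:: python

    import xgboost as xgb

    booster = xgb.train(
        params, dtrain,
        num_boost_round=100,
        evals=[(dval, "eval")],
        verbose_eval=4,  # print the metric every 4 rounds, not every round
    )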
@ -9,7 +9,9 @@ rng = np.random.RandomState(1994)


 class TestInteractionConstraints:
-    def run_interaction_constraints(self, tree_method, feature_names=None, interaction_constraints='[[0, 1]]'):
+    def run_interaction_constraints(
+        self, tree_method, feature_names=None, interaction_constraints='[[0, 1]]'
+    ):
         x1 = np.random.normal(loc=1.0, scale=1.0, size=1000)
         x2 = np.random.normal(loc=1.0, scale=1.0, size=1000)
         x3 = np.random.choice([1, 2, 3], size=1000, replace=True)
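For readers unfamiliar with the test subject, the `interaction_constraints='[[0, 1]]'` string means features 0 and 1 may interact with each other but not with any other feature. A minimal hedged usage sketch (synthetic data; the `hist` tree method is chosen arbitrarily):

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    rng = np.random.RandomState(1994)
    X = rng.normal(size=(1000, 3))
    y = X[:, 0] * X[:, 1] + X[:, 2]
    booster = xgb.train(
        {"interaction_constraints": "[[0, 1]]", "tree_method": "hist"},
        xgb.DMatrix(X, label=y),
        num_boost_round=10,
    )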