Move skl eval_metric and early_stopping rounds to model params. (#6751)

A new parameter `custom_metric` is added to `train` and `cv` to distinguish the behaviour from the old `feval`.  And `feval` is deprecated.  The new `custom_metric` receives transformed prediction when the built-in objective is used.  This enables XGBoost to use cost functions from other libraries like scikit-learn directly without going through the definition of the link function.

`eval_metric` and `early_stopping_rounds` in sklearn interface are moved from `fit` to `__init__` and is now saved as part of the scikit-learn model.  The old ones in `fit` function are now deprecated. The new `eval_metric` in `__init__` has the same new behaviour as `custom_metric`.

Added more detailed documents for the behaviour of custom objective and metric.
This commit is contained in:
Jiaming Yuan
2021-10-28 17:20:20 +08:00
committed by GitHub
parent 6b074add66
commit 45aef75cca
13 changed files with 685 additions and 190 deletions

View File

@@ -103,10 +103,13 @@ class CallbackContainer:
EvalsLog = TrainingCallback.EvalsLog
def __init__(self,
callbacks: List[TrainingCallback],
metric: Callable = None,
is_cv: bool = False):
def __init__(
self,
callbacks: List[TrainingCallback],
metric: Callable = None,
output_margin: bool = True,
is_cv: bool = False
) -> None:
self.callbacks = set(callbacks)
if metric is not None:
msg = 'metric must be callable object for monitoring. For ' + \
@@ -115,6 +118,7 @@ class CallbackContainer:
assert callable(metric), msg
self.metric = metric
self.history: TrainingCallback.EvalsLog = collections.OrderedDict()
self._output_margin = output_margin
self.is_cv = is_cv
if self.is_cv:
@@ -171,7 +175,7 @@ class CallbackContainer:
def after_iteration(self, model, epoch, dtrain, evals) -> bool:
'''Function called after training iteration.'''
if self.is_cv:
scores = model.eval(epoch, self.metric)
scores = model.eval(epoch, self.metric, self._output_margin)
scores = _aggcv(scores)
self.aggregated_cv = scores
self._update_history(scores, epoch)
@@ -179,7 +183,7 @@ class CallbackContainer:
evals = [] if evals is None else evals
for _, name in evals:
assert name.find('-') == -1, 'Dataset name should not contain `-`'
score = model.eval_set(evals, epoch, self.metric)
score = model.eval_set(evals, epoch, self.metric, self._output_margin)
score = score.split()[1:] # into datasets
# split up `test-error:0.1234`
score = [tuple(s.split(':')) for s in score]

View File

@@ -1700,7 +1700,7 @@ class Booster(object):
c_array(ctypes.c_float, hess),
c_bst_ulong(len(grad))))
def eval_set(self, evals, iteration=0, feval=None):
def eval_set(self, evals, iteration=0, feval=None, output_margin=True):
# pylint: disable=invalid-name
"""Evaluate a set of data.
@@ -1728,24 +1728,30 @@ class Booster(object):
dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals])
evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals])
msg = ctypes.c_char_p()
_check_call(_LIB.XGBoosterEvalOneIter(self.handle,
ctypes.c_int(iteration),
dmats, evnames,
c_bst_ulong(len(evals)),
ctypes.byref(msg)))
_check_call(
_LIB.XGBoosterEvalOneIter(
self.handle,
ctypes.c_int(iteration),
dmats,
evnames,
c_bst_ulong(len(evals)),
ctypes.byref(msg),
)
)
res = msg.value.decode() # pylint: disable=no-member
if feval is not None:
for dmat, evname in evals:
feval_ret = feval(self.predict(dmat, training=False,
output_margin=True), dmat)
feval_ret = feval(
self.predict(dmat, training=False, output_margin=output_margin), dmat
)
if isinstance(feval_ret, list):
for name, val in feval_ret:
# pylint: disable=consider-using-f-string
res += '\t%s-%s:%f' % (evname, name, val)
res += "\t%s-%s:%f" % (evname, name, val)
else:
name, val = feval_ret
# pylint: disable=consider-using-f-string
res += '\t%s-%s:%f' % (evname, name, val)
res += "\t%s-%s:%f" % (evname, name, val)
return res
def eval(self, data, name='eval', iteration=0):

View File

@@ -844,6 +844,7 @@ async def _train_async(
verbose_eval: Union[int, bool],
xgb_model: Optional[Booster],
callbacks: Optional[List[TrainingCallback]],
custom_metric: Optional[Metric],
) -> Optional[TrainReturnT]:
workers = _get_workers_from_data(dtrain, evals)
_rabit_args = await _get_rabit_args(len(workers), client)
@@ -896,6 +897,7 @@ async def _train_async(
evals=local_evals,
obj=obj,
feval=feval,
custom_metric=custom_metric,
early_stopping_rounds=early_stopping_rounds,
verbose_eval=verbose_eval,
xgb_model=xgb_model,
@@ -942,11 +944,13 @@ async def _train_async(
return list(filter(lambda ret: ret is not None, results))[0]
@_deprecate_positional_args
def train( # pylint: disable=unused-argument
client: "distributed.Client",
params: Dict[str, Any],
dtrain: DaskDMatrix,
num_boost_round: int = 10,
*,
evals: Optional[List[Tuple[DaskDMatrix, str]]] = None,
obj: Optional[Objective] = None,
feval: Optional[Metric] = None,
@@ -954,6 +958,7 @@ def train( # pylint: disable=unused-argument
xgb_model: Optional[Booster] = None,
verbose_eval: Union[int, bool] = True,
callbacks: Optional[List[TrainingCallback]] = None,
custom_metric: Optional[Metric] = None,
) -> Any:
"""Train XGBoost model.
@@ -1647,7 +1652,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
eval_metric: Optional[Union[str, List[str], Metric]],
sample_weight_eval_set: Optional[List[_DaskCollection]],
base_margin_eval_set: Optional[List[_DaskCollection]],
early_stopping_rounds: int,
early_stopping_rounds: Optional[int],
verbose: bool,
xgb_model: Optional[Union[Booster, XGBModel]],
feature_weights: Optional[_DaskCollection],
@@ -1676,8 +1681,8 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
obj: Optional[Callable] = _objective_decorator(self.objective)
else:
obj = None
model, metric, params = self._configure_fit(
booster=xgb_model, eval_metric=eval_metric, params=params
model, metric, params, early_stopping_rounds = self._configure_fit(
xgb_model, eval_metric, params, early_stopping_rounds
)
results = await self.client.sync(
_train_async,
@@ -1689,7 +1694,8 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
num_boost_round=self.get_num_boosting_rounds(),
evals=evals,
obj=obj,
feval=metric,
feval=None,
custom_metric=metric,
verbose_eval=verbose,
early_stopping_rounds=early_stopping_rounds,
callbacks=callbacks,
@@ -1736,7 +1742,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
eval_metric: Optional[Union[str, List[str], Metric]],
sample_weight_eval_set: Optional[List[_DaskCollection]],
base_margin_eval_set: Optional[List[_DaskCollection]],
early_stopping_rounds: int,
early_stopping_rounds: Optional[int],
verbose: bool,
xgb_model: Optional[Union[Booster, XGBModel]],
feature_weights: Optional[_DaskCollection],
@@ -1778,8 +1784,8 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
obj: Optional[Callable] = _objective_decorator(self.objective)
else:
obj = None
model, metric, params = self._configure_fit(
booster=xgb_model, eval_metric=eval_metric, params=params
model, metric, params, early_stopping_rounds = self._configure_fit(
xgb_model, eval_metric, params, early_stopping_rounds
)
results = await self.client.sync(
_train_async,
@@ -1791,7 +1797,8 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
num_boost_round=self.get_num_boosting_rounds(),
evals=evals,
obj=obj,
feval=metric,
feval=None,
custom_metric=metric,
verbose_eval=verbose,
early_stopping_rounds=early_stopping_rounds,
callbacks=callbacks,
@@ -1832,9 +1839,14 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
base_margin: Optional[_DaskCollection],
iteration_range: Optional[Tuple[int, int]],
) -> _DaskCollection:
if self.objective == "multi:softmax":
raise ValueError(
"multi:softmax doesn't support `predict_proba`. "
"Switch to `multi:softproba` instead"
)
predts = await super()._predict_async(
data=X,
output_margin=self.objective == "multi:softmax",
output_margin=False,
validate_features=validate_features,
base_margin=base_margin,
iteration_range=iteration_range,
@@ -1903,9 +1915,9 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
""",
["estimators", "model"],
end_note="""
Note
----
For dask implementation, group is not supported, use qid instead.
.. note::
For dask implementation, group is not supported, use qid instead.
""",
)
class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
@@ -1929,7 +1941,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
eval_group: Optional[List[_DaskCollection]],
eval_qid: Optional[List[_DaskCollection]],
eval_metric: Optional[Union[str, List[str], Metric]],
early_stopping_rounds: int,
early_stopping_rounds: Optional[int],
verbose: bool,
xgb_model: Optional[Union[XGBModel, Booster]],
feature_weights: Optional[_DaskCollection],
@@ -1963,8 +1975,8 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
raise ValueError(
"Custom evaluation metric is not yet supported for XGBRanker."
)
model, metric, params = self._configure_fit(
booster=xgb_model, eval_metric=eval_metric, params=params
model, metric, params, early_stopping_rounds = self._configure_fit(
xgb_model, eval_metric, params, early_stopping_rounds
)
results = await self.client.sync(
_train_async,
@@ -1976,7 +1988,8 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
num_boost_round=self.get_num_boosting_rounds(),
evals=evals,
obj=None,
feval=metric,
feval=None,
custom_metric=metric,
verbose_eval=verbose,
early_stopping_rounds=early_stopping_rounds,
callbacks=callbacks,

View File

@@ -89,6 +89,19 @@ def _objective_decorator(
return inner
def _metric_decorator(func: Callable) -> Metric:
"""Decorate a metric function from sklearn.
Converts an metric function that uses the typical sklearn metric signature so that it
is compatible with :py:func:`train`
"""
def inner(y_score: np.ndarray, dmatrix: DMatrix) -> Tuple[str, float]:
y_true = dmatrix.get_label()
return func.__name__, func(y_true, y_score)
return inner
__estimator_doc = '''
n_estimators : int
Number of gradient boosted trees. Equivalent to number of boosting
@@ -183,6 +196,66 @@ __model_doc = f'''
Experimental support for categorical data. Do not set to true unless you are
interested in development. Only valid when `gpu_hist` and dataframe are used.
eval_metric : Optional[Union[str, List[str], Callable]]
.. versionadded:: 1.5.1
Metric used for monitoring the training result and early stopping. It can be a
string or list of strings as names of predefined metric in XGBoost (See
doc/parameter.rst), one of the metrics in :py:mod:`sklearn.metrics`, or any other
user defined metric that looks like `sklearn.metrics`.
If custom objective is also provided, then custom metric should implement the
corresponding reverse link function.
Unlike the `scoring` parameter commonly used in scikit-learn, when a callable
object is provided, it's assumed to be a cost function and by default XGBoost will
minimize the result during early stopping.
For advanced usage on Early stopping like directly choosing to maximize instead of
minimize, see :py:obj:`xgboost.callback.EarlyStopping`.
See `Custom Objective and Evaluation Metric
<https://xgboost.readthedocs.io/en/latest/tutorials/custom_metric_obj.html>`_ for
more.
.. note::
This parameter replaces `eval_metric` in :py:meth:`fit` method. The old one
receives un-transformed prediction regardless of whether custom objective is
being used.
.. code-block:: python
from sklearn.datasets import load_diabetes
from sklearn.metrics import mean_absolute_error
X, y = load_diabetes(return_X_y=True)
reg = xgb.XGBRegressor(
tree_method="hist",
eval_metric=mean_absolute_error,
)
reg.fit(X, y, eval_set=[(X, y)])
early_stopping_rounds : Optional[int]
.. versionadded:: 1.5.1
Activates early stopping. Validation metric needs to improve at least once in
every **early_stopping_rounds** round(s) to continue training. Requires at least
one item in **eval_set** in :py:meth:`xgboost.sklearn.XGBModel.fit`.
The method returns the model from the last iteration (not the best one). If
there's more than one item in **eval_set**, the last entry will be used for early
stopping. If there's more than one metric in **eval_metric**, the last metric
will be used for early stopping.
If early stopping occurs, the model will have three additional fields:
``clf.best_score``, ``clf.best_iteration`` and ``clf.best_ntree_limit``.
.. note::
This parameter replaces `early_stopping_rounds` in :py:meth:`fit` method.
kwargs : dict, optional
Keyword arguments for XGBoost Booster object. Full documentation of
parameters can be found here:
@@ -397,6 +470,8 @@ class XGBModel(XGBModelBase):
validate_parameters: Optional[bool] = None,
predictor: Optional[str] = None,
enable_categorical: bool = False,
eval_metric: Optional[Union[str, List[str], Callable]] = None,
early_stopping_rounds: Optional[int] = None,
**kwargs: Any
) -> None:
if not SKLEARN_INSTALLED:
@@ -433,6 +508,8 @@ class XGBModel(XGBModelBase):
self.validate_parameters = validate_parameters
self.predictor = predictor
self.enable_categorical = enable_categorical
self.eval_metric = eval_metric
self.early_stopping_rounds = early_stopping_rounds
if kwargs:
self.kwargs = kwargs
@@ -543,8 +620,13 @@ class XGBModel(XGBModelBase):
params = self.get_params()
# Parameters that should not go into native learner.
wrapper_specific = {
'importance_type', 'kwargs', 'missing', 'n_estimators', 'use_label_encoder',
"enable_categorical"
"importance_type",
"kwargs",
"missing",
"n_estimators",
"use_label_encoder",
"enable_categorical",
"early_stopping_rounds",
}
filtered = {}
for k, v in params.items():
@@ -629,32 +711,80 @@ class XGBModel(XGBModelBase):
load_model.__doc__ = f"""{Booster.load_model.__doc__}"""
# pylint: disable=too-many-branches
def _configure_fit(
self,
booster: Optional[Union[Booster, "XGBModel", str]],
eval_metric: Optional[Union[Callable, str, List[str]]],
params: Dict[str, Any],
) -> Tuple[Optional[Union[Booster, str]], Optional[Metric], Dict[str, Any]]:
# pylint: disable=protected-access, no-self-use
early_stopping_rounds: Optional[int],
) -> Tuple[
Optional[Union[Booster, str, "XGBModel"]],
Optional[Metric],
Dict[str, Any],
Optional[int],
]:
"""Configure parameters for :py:meth:`fit`."""
if isinstance(booster, XGBModel):
# Handle the case when xgb_model is a sklearn model object
model: Optional[Union[Booster, str]] = booster._Booster
model: Optional[Union[Booster, str]] = booster.get_booster()
else:
model = booster
feval = eval_metric if callable(eval_metric) else None
def _deprecated(parameter: str) -> None:
warnings.warn(
f"`{parameter}` in `fit` method is deprecated for better compatibility "
f"with scikit-learn, use `{parameter}` in constructor or`set_params` "
"instead.",
UserWarning,
)
def _duplicated(parameter: str) -> None:
raise ValueError(
f"2 different `{parameter}` are provided. Use the one in constructor "
"or `set_params` instead."
)
# Configure evaluation metric.
if eval_metric is not None:
if callable(eval_metric):
eval_metric = None
_deprecated("eval_metric")
if self.eval_metric is not None and eval_metric is not None:
_duplicated("eval_metric")
# - track where does the evaluation metric come from
if self.eval_metric is not None:
from_fit = False
eval_metric = self.eval_metric
else:
from_fit = True
# - configure callable evaluation metric
metric: Optional[Metric] = None
if eval_metric is not None:
if callable(eval_metric) and from_fit:
# No need to wrap the evaluation function for old parameter.
metric = eval_metric
elif callable(eval_metric):
# Parameter from constructor or set_params
metric = _metric_decorator(eval_metric)
else:
params.update({"eval_metric": eval_metric})
# Configure early_stopping_rounds
if early_stopping_rounds is not None:
_deprecated("early_stopping_rounds")
if early_stopping_rounds is not None and self.early_stopping_rounds is not None:
_duplicated("early_stopping_rounds")
early_stopping_rounds = (
self.early_stopping_rounds
if self.early_stopping_rounds is not None
else early_stopping_rounds
)
if self.enable_categorical and params.get("tree_method", None) != "gpu_hist":
raise ValueError(
"Experimental support for categorical data is not implemented for"
" current tree method yet."
)
return model, feval, params
return model, metric, params, early_stopping_rounds
def _set_evaluation_result(self, evals_result: TrainingCallback.EvalsLog) -> None:
if evals_result:
@@ -702,31 +832,15 @@ class XGBModel(XGBModelBase):
A list of (X, y) tuple pairs to use as validation sets, for which
metrics will be computed.
Validation metrics will help us track the performance of the model.
eval_metric :
If a str, should be a built-in evaluation metric to use. See doc/parameter.rst.
If a list of str, should be the list of multiple built-in evaluation metrics
to use.
eval_metric : str, list of str, or callable, optional
.. deprecated:: 1.5.1
Use `eval_metric` in :py:meth:`__init__` or :py:meth:`set_params` instead.
If callable, a custom evaluation metric. The call signature is
``func(y_predicted, y_true)`` where ``y_true`` will be a DMatrix object such
that you may need to call the ``get_label`` method. It must return a str,
value pair where the str is a name for the evaluation and value is the value
of the evaluation function. The callable custom objective is always minimized.
early_stopping_rounds :
Activates early stopping. Validation metric needs to improve at least once in
every **early_stopping_rounds** round(s) to continue training.
Requires at least one item in **eval_set**.
The method returns the model from the last iteration (not the best one).
If there's more than one item in **eval_set**, the last entry will be used
for early stopping.
If there's more than one metric in **eval_metric**, the last metric will be
used for early stopping.
If early stopping occurs, the model will have three additional fields:
``clf.best_score``, ``clf.best_iteration``.
early_stopping_rounds : int
.. deprecated:: 1.5.1
Use `early_stopping_rounds` in :py:meth:`__init__` or
:py:meth:`set_params` instead.
verbose :
If `verbose` and an evaluation set is used, writes the evaluation metric
measured on the validation set to stderr.
@@ -783,7 +897,9 @@ class XGBModel(XGBModelBase):
else:
obj = None
model, feval, params = self._configure_fit(xgb_model, eval_metric, params)
model, metric, params, early_stopping_rounds = self._configure_fit(
xgb_model, eval_metric, params, early_stopping_rounds
)
self._Booster = train(
params,
train_dmatrix,
@@ -792,7 +908,7 @@ class XGBModel(XGBModelBase):
early_stopping_rounds=early_stopping_rounds,
evals_result=evals_result,
obj=obj,
feval=feval,
custom_metric=metric,
verbose_eval=verbose,
xgb_model=model,
callbacks=callbacks,
@@ -1185,12 +1301,14 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
obj = None
if self.n_classes_ > 2:
# Switch to using a multiclass objective in the underlying
# XGB instance
params["objective"] = "multi:softprob"
# Switch to using a multiclass objective in the underlying XGB instance
if params.get("objective", None) != "multi:softmax":
params["objective"] = "multi:softprob"
params["num_class"] = self.n_classes_
model, feval, params = self._configure_fit(xgb_model, eval_metric, params)
model, metric, params, early_stopping_rounds = self._configure_fit(
xgb_model, eval_metric, params, early_stopping_rounds
)
train_dmatrix, evals = _wrap_evaluation_matrices(
missing=self.missing,
X=X,
@@ -1217,7 +1335,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
early_stopping_rounds=early_stopping_rounds,
evals_result=evals_result,
obj=obj,
feval=feval,
custom_metric=metric,
verbose_eval=verbose,
xgb_model=model,
callbacks=callbacks,
@@ -1304,12 +1422,19 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
"""
# custom obj: Do nothing as we don't know what to do.
# softprob: Do nothing, output is proba.
# softmax: Use output margin to remove the argmax in PredTransform.
# softmax: Unsupported by predict_proba()
# binary:logistic: Expand the prob vector into 2-class matrix after predict.
# binary:logitraw: Unsupported by predict_proba()
if self.objective == "multi:softmax":
# We need to run a Python implementation of softmax for it. Just ask user to
# use softprob since XGBoost's implementation has mitigation for floating
# point overflow. No need to reinvent the wheel.
raise ValueError(
"multi:softmax doesn't support `predict_proba`. "
"Switch to `multi:softproba` instead"
)
class_probs = super().predict(
X=X,
output_margin=self.objective == "multi:softmax",
ntree_limit=ntree_limit,
validate_features=validate_features,
base_margin=base_margin,
@@ -1325,8 +1450,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
If **eval_set** is passed to the `fit` function, you can call
``evals_result()`` to get evaluation results for all passed **eval_sets**.
When **eval_metric** is also passed to the `fit` function, the
**evals_result** will contain the **eval_metrics** passed to the `fit` function.
When **eval_metric** is also passed as a parameter, the **evals_result** will
contain the **eval_metric** passed to the `fit` function.
Returns
-------
@@ -1337,13 +1463,14 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
.. code-block:: python
param_dist = {'objective':'binary:logistic', 'n_estimators':2}
param_dist = {
'objective':'binary:logistic', 'n_estimators':2, eval_metric="logloss"
}
clf = xgb.XGBClassifier(**param_dist)
clf.fit(X_train, y_train,
eval_set=[(X_train, y_train), (X_test, y_test)],
eval_metric='logloss',
verbose=True)
evals_result = clf.evals_result()
@@ -1354,6 +1481,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
{'validation_0': {'logloss': ['0.604835', '0.531479']},
'validation_1': {'logloss': ['0.41965', '0.17686']}}
"""
if self.evals_result_:
evals_result = self.evals_result_
@@ -1386,6 +1514,7 @@ class XGBRFClassifier(XGBClassifier):
colsample_bynode=colsample_bynode,
reg_lambda=reg_lambda,
**kwargs)
_check_rf_callback(self.early_stopping_rounds, None)
def get_xgb_params(self) -> Dict[str, Any]:
params = super().get_xgb_params()
@@ -1457,6 +1586,7 @@ class XGBRFRegressor(XGBRegressor):
reg_lambda=reg_lambda,
**kwargs
)
_check_rf_callback(self.early_stopping_rounds, None)
def get_xgb_params(self) -> Dict[str, Any]:
params = super().get_xgb_params()
@@ -1495,15 +1625,15 @@ class XGBRFRegressor(XGBRegressor):
'Implementation of the Scikit-Learn API for XGBoost Ranking.',
['estimators', 'model'],
end_note='''
Note
----
A custom objective function is currently not supported by XGBRanker.
Likewise, a custom metric function is not supported either.
.. note::
Note
----
Query group information is required for ranking tasks by either using the `group`
parameter or `qid` parameter in `fit` method.
A custom objective function is currently not supported by XGBRanker.
Likewise, a custom metric function is not supported either.
.. note::
Query group information is required for ranking tasks by either using the
`group` parameter or `qid` parameter in `fit` method.
Before fitting the model, your data need to be sorted by query group. When fitting
the model, you need to provide an additional array that contains the size of each
@@ -1605,22 +1735,16 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
eval_qid :
A list in which ``eval_qid[i]`` is the array containing query ID of ``i``-th
pair in **eval_set**.
eval_metric :
If a str, should be a built-in evaluation metric to use. See
doc/parameter.rst.
If a list of str, should be the list of multiple built-in evaluation metrics
to use. The custom evaluation metric is not yet supported for the ranker.
early_stopping_rounds :
Activates early stopping. Validation metric needs to improve at least once in
every **early_stopping_rounds** round(s) to continue training. Requires at
least one item in **eval_set**.
The method returns the model from the last iteration (not the best one). If
there's more than one item in **eval_set**, the last entry will be used for
early stopping.
If there's more than one metric in **eval_metric**, the last metric will be
used for early stopping.
If early stopping occurs, the model will have three additional fields:
``clf.best_score``, ``clf.best_iteration`` and ``clf.best_ntree_limit``.
eval_metric : str, list of str, optional
.. deprecated:: 1.5.1
use `eval_metric` in :py:meth:`__init__` or :py:meth:`set_params` instead.
early_stopping_rounds : int
.. deprecated:: 1.5.1
use `early_stopping_rounds` in :py:meth:`__init__` or
:py:meth:`set_params` instead.
verbose :
If `verbose` and an evaluation set is used, writes the evaluation metric
measured on the validation set to stderr.
@@ -1685,8 +1809,10 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
evals_result: TrainingCallback.EvalsLog = {}
params = self.get_xgb_params()
model, feval, params = self._configure_fit(xgb_model, eval_metric, params)
if callable(feval):
model, metric, params, early_stopping_rounds = self._configure_fit(
xgb_model, eval_metric, params, early_stopping_rounds
)
if callable(metric):
raise ValueError(
'Custom evaluation metric is not yet supported for XGBRanker.'
)
@@ -1696,7 +1822,8 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
self.n_estimators,
early_stopping_rounds=early_stopping_rounds,
evals=evals,
evals_result=evals_result, feval=feval,
evals_result=evals_result,
custom_metric=metric,
verbose_eval=verbose, xgb_model=model,
callbacks=callbacks
)

View File

@@ -4,9 +4,12 @@
"""Training Library containing training routines."""
import copy
from typing import Optional, List
import warnings
import numpy as np
from .core import Booster, XGBoostError, _get_booster_layer_trees
from .core import _deprecate_positional_args
from .core import Objective, Metric
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
from . import callback
@@ -22,21 +25,48 @@ def _assert_new_callback(callbacks: Optional[List[callback.TrainingCallback]]) -
)
def _train_internal(params, dtrain,
num_boost_round=10, evals=(),
obj=None, feval=None,
xgb_model=None, callbacks=None,
evals_result=None, maximize=None,
verbose_eval=None, early_stopping_rounds=None):
def _configure_custom_metric(
feval: Optional[Metric], custom_metric: Optional[Metric]
) -> Optional[Metric]:
if feval is not None:
link = "https://xgboost.readthedocs.io/en/latest/tutorials/custom_metric_obj.html"
warnings.warn(
"`feval` is deprecated, use `custom_metric` instead. They have "
"different behavior when custom objective is also used."
f"See {link} for details on the `custom_metric`."
)
if feval is not None and custom_metric is not None:
raise ValueError(
"Bost `feval` and `custom_metric` are supplied. Use `custom_metric` instead."
)
eval_metric = custom_metric if custom_metric is not None else feval
return eval_metric
def _train_internal(
params,
dtrain,
num_boost_round=10,
evals=(),
obj=None,
feval=None,
custom_metric=None,
xgb_model=None,
callbacks=None,
evals_result=None,
maximize=None,
verbose_eval=None,
early_stopping_rounds=None,
):
"""internal training function"""
callbacks = [] if callbacks is None else copy.copy(callbacks)
metric_fn = _configure_custom_metric(feval, custom_metric)
evals = list(evals)
bst = Booster(params, [dtrain] + [d[0] for d in evals])
if xgb_model is not None:
bst = Booster(params, [dtrain] + [d[0] for d in evals],
model_file=xgb_model)
bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model)
start_iteration = 0
@@ -48,7 +78,14 @@ def _train_internal(params, dtrain,
callbacks.append(
callback.EarlyStopping(rounds=early_stopping_rounds, maximize=maximize)
)
callbacks = callback.CallbackContainer(callbacks, metric=feval)
callbacks = callback.CallbackContainer(
callbacks,
metric=metric_fn,
# For old `feval` parameter, the behavior is unchanged. For the new
# `custom_metric`, it will receive proper prediction result when custom objective
# is not used.
output_margin=callable(obj) or metric_fn is feval,
)
bst = callbacks.before_training(bst)
@@ -89,9 +126,23 @@ def _train_internal(params, dtrain,
return bst.copy()
def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
maximize=None, early_stopping_rounds=None, evals_result=None,
verbose_eval=True, xgb_model=None, callbacks=None):
@_deprecate_positional_args
def train(
params,
dtrain,
num_boost_round=10,
*,
evals=(),
obj: Optional[Objective] = None,
feval=None,
maximize=None,
early_stopping_rounds=None,
evals_result=None,
verbose_eval=True,
xgb_model=None,
callbacks=None,
custom_metric: Optional[Metric] = None,
):
# pylint: disable=too-many-statements,too-many-branches, attribute-defined-outside-init
"""Train a booster with given parameters.
@@ -106,10 +157,13 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
evals: list of pairs (DMatrix, string)
List of validation sets for which metrics will evaluated during training.
Validation metrics will help us track the performance of the model.
obj : function
Customized objective function.
feval : function
Customized evaluation function.
obj
Custom objective function. See `Custom Objective
<https://xgboost.readthedocs.io/en/latest/tutorials/custom_metric_obj.html>`_ for
details.
feval :
.. deprecated:: 1.5.1
Use `custom_metric` instead.
maximize : bool
Whether to maximize feval.
early_stopping_rounds: int
@@ -158,23 +212,37 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
[xgb.callback.LearningRateScheduler(custom_rates)]
custom_metric:
.. versionadded 1.5.1
Custom metric function. See `Custom Metric
<https://xgboost.readthedocs.io/en/latest/tutorials/custom_metric_obj.html>`_ for
details.
Returns
-------
Booster : a trained booster model
"""
bst = _train_internal(params, dtrain,
num_boost_round=num_boost_round,
evals=evals,
obj=obj, feval=feval,
xgb_model=xgb_model, callbacks=callbacks,
verbose_eval=verbose_eval,
evals_result=evals_result,
maximize=maximize,
early_stopping_rounds=early_stopping_rounds)
bst = _train_internal(
params,
dtrain,
num_boost_round=num_boost_round,
evals=evals,
obj=obj,
feval=feval,
xgb_model=xgb_model,
callbacks=callbacks,
verbose_eval=verbose_eval,
evals_result=evals_result,
maximize=maximize,
early_stopping_rounds=early_stopping_rounds,
custom_metric=custom_metric,
)
return bst
class CVPack(object):
class CVPack:
""""Auxiliary datastruct to hold one fold of CV."""
def __init__(self, dtrain, dtest, param):
""""Initialize the CVPack"""
@@ -192,9 +260,9 @@ class CVPack(object):
""""Update the boosters for one iteration"""
self.bst.update(self.dtrain, iteration, fobj)
def eval(self, iteration, feval):
def eval(self, iteration, feval, output_margin):
""""Evaluate the CVPack for one iteration."""
return self.bst.eval_set(self.watchlist, iteration, feval)
return self.bst.eval_set(self.watchlist, iteration, feval, output_margin)
class _PackedBooster:
@@ -206,9 +274,9 @@ class _PackedBooster:
for fold in self.cvfolds:
fold.update(iteration, obj)
def eval(self, iteration, feval):
def eval(self, iteration, feval, output_margin):
'''Iterate through folds for eval'''
result = [f.eval(iteration, feval) for f in self.cvfolds]
result = [f.eval(iteration, feval, output_margin) for f in self.cvfolds]
return result
def set_attr(self, **kwargs):
@@ -345,9 +413,10 @@ def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False,
def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None,
metrics=(), obj=None, feval=None, maximize=None, early_stopping_rounds=None,
metrics=(), obj: Optional[Objective] = None,
feval=None, maximize=None, early_stopping_rounds=None,
fpreproc=None, as_pandas=True, verbose_eval=None, show_stdv=True,
seed=0, callbacks=None, shuffle=True):
seed=0, callbacks=None, shuffle=True, custom_metric: Optional[Metric] = None):
# pylint: disable = invalid-name
"""Cross-validation with given parameters.
@@ -372,10 +441,15 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
indices to be used as the testing samples for the ``n`` th fold.
metrics : string or list of strings
Evaluation metrics to be watched in CV.
obj : function
Custom objective function.
obj :
Custom objective function. See `Custom Objective
<https://xgboost.readthedocs.io/en/latest/tutorials/custom_metric_obj.html>`_ for
details.
feval : function
Custom evaluation function.
.. deprecated:: 1.5.1
Use `custom_metric` instead.
maximize : bool
Whether to maximize feval.
early_stopping_rounds: int
@@ -412,6 +486,13 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
[xgb.callback.LearningRateScheduler(custom_rates)]
shuffle : bool
Shuffle data before creating folds.
custom_metric :
.. versionadded 1.5.1
Custom metric function. See `Custom Metric
<https://xgboost.readthedocs.io/en/latest/tutorials/custom_metric_obj.html>`_ for
details.
Returns
-------
@@ -443,6 +524,8 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
cvfolds = mknfold(dtrain, nfold, params, seed, metrics, fpreproc,
stratified, folds, shuffle)
metric_fn = _configure_custom_metric(feval, custom_metric)
# setup callbacks
callbacks = [] if callbacks is None else callbacks
_assert_new_callback(callbacks)
@@ -456,7 +539,12 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
callbacks.append(
callback.EarlyStopping(rounds=early_stopping_rounds, maximize=maximize)
)
callbacks = callback.CallbackContainer(callbacks, metric=feval, is_cv=True)
callbacks = callback.CallbackContainer(
callbacks,
metric=metric_fn,
is_cv=True,
output_margin=callable(obj) or metric_fn is feval,
)
booster = _PackedBooster(cvfolds)
callbacks.before_training(booster)