Move skl eval_metric and early_stopping_rounds to model params. (#6751)

A new parameter `custom_metric` is added to `train` and `cv` to distinguish its behaviour from the old `feval`, which is now deprecated.  The new `custom_metric` receives the transformed prediction when a built-in objective is used.  This enables XGBoost to use cost functions from other libraries like scikit-learn directly, without going through the definition of the link function.

`eval_metric` and `early_stopping_rounds` in the sklearn interface are moved from `fit` to `__init__` and are now saved as part of the scikit-learn model.  The old parameters in the `fit` function are now deprecated. The new `eval_metric` in `__init__` has the same new behaviour as `custom_metric`.

Added more detailed documentation for the behaviour of custom objectives and metrics.
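
For example, after this change the sklearn wrapper accepts both parameters in the constructor (a minimal sketch based on the new documentation):

    from sklearn.datasets import load_diabetes
    from sklearn.metrics import mean_absolute_error
    import xgboost as xgb

    X, y = load_diabetes(return_X_y=True)
    reg = xgb.XGBRegressor(
        eval_metric=mean_absolute_error,  # receives transformed predictions
        early_stopping_rounds=5,          # saved as part of the model
    )
    reg.fit(X, y, eval_set=[(X, y)])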
Jiaming Yuan 2021-10-28 17:20:20 +08:00 committed by GitHub
parent 6b074add66
commit 45aef75cca
13 changed files with 685 additions and 190 deletions

View File

@ -144,7 +144,7 @@ def py_rmsle(dtrain: xgb.DMatrix, dtest: xgb.DMatrix) -> Dict:
dtrain=dtrain, dtrain=dtrain,
num_boost_round=kBoostRound, num_boost_round=kBoostRound,
obj=squared_log, obj=squared_log,
feval=rmsle, custom_metric=rmsle,
evals=[(dtrain, 'dtrain'), (dtest, 'dtest')], evals=[(dtrain, 'dtrain'), (dtest, 'dtest')],
evals_result=results) evals_result=results)

View File

@ -3,6 +3,9 @@ only applicable after (excluding) XGBoost 1.0.0, as before this version XGBoost
returns transformed prediction for multi-class objective function. More returns transformed prediction for multi-class objective function. More
details in comments. details in comments.
See https://xgboost.readthedocs.io/en/latest/tutorials/custom_metric_obj.html for detailed
tutorial and notes.
''' '''
import numpy as np import numpy as np
@ -95,7 +98,12 @@ def predict(booster: xgb.Booster, X):
def merror(predt: np.ndarray, dtrain: xgb.DMatrix): def merror(predt: np.ndarray, dtrain: xgb.DMatrix):
y = dtrain.get_label() y = dtrain.get_label()
# Like custom objective, the predt is untransformed leaf weight # Like custom objective, the predt is untransformed leaf weight when custom objective
# is provided.
# With the use of `custom_metric` parameter in train function, custom metric receives
# raw input only when custom objective is also being used. Otherwise custom metric
# will receive transformed prediction.
assert predt.shape == (kRows, kClasses) assert predt.shape == (kRows, kClasses)
out = np.zeros(kRows) out = np.zeros(kRows)
for r in range(predt.shape[0]): for r in range(predt.shape[0]):
@ -134,7 +142,7 @@ def main(args):
m, m,
num_boost_round=kRounds, num_boost_round=kRounds,
obj=softprob_obj, obj=softprob_obj,
feval=merror, custom_metric=merror,
evals_result=custom_results, evals_result=custom_results,
evals=[(m, 'train')]) evals=[(m, 'train')])
@ -143,6 +151,7 @@ def main(args):
native_results = {} native_results = {}
# Use the same objective function defined in XGBoost. # Use the same objective function defined in XGBoost.
booster_native = xgb.train({'num_class': kClasses, booster_native = xgb.train({'num_class': kClasses,
"objective": "multi:softmax",
'eval_metric': 'merror'}, 'eval_metric': 'merror'},
m, m,
num_boost_round=kRounds, num_boost_round=kRounds,

View File

@ -2,6 +2,16 @@
Custom Objective and Evaluation Metric Custom Objective and Evaluation Metric
###################################### ######################################
**Contents**
.. contents::
:backlinks: none
:local:
********
Overview
********
XGBoost is designed to be an extensible library. One way to extend it is by providing our XGBoost is designed to be an extensible library. One way to extend it is by providing our
own objective function for training and corresponding metric for performance monitoring. own objective function for training and corresponding metric for performance monitoring.
This document introduces implementing a customized elementwise evaluation metric and This document introduces implementing a customized elementwise evaluation metric and
@ -11,12 +21,8 @@ concepts should be readily applicable to other language bindings.
.. note:: .. note::
* The ranking task does not support customized functions. * The ranking task does not support customized functions.
* The customized functions defined here are only applicable to single node training.
Distributed environment requires syncing with ``xgboost.rabit``, the interface is
subject to change hence beyond the scope of this tutorial.
* We also plan to improve the interface for multi-classes objective in the future.
In the following sections, we will provide a step by step walk through of implementing In the following two sections, we will provide a step by step walk through of implementing
``Squared Log Error(SLE)`` objective function: ``Squared Log Error(SLE)`` objective function:
.. math:: .. math::
@ -30,7 +36,10 @@ and its default metric ``Root Mean Squared Log Error(RMSLE)``:
Although XGBoost has native support for said functions, using it for demonstration Although XGBoost has native support for said functions, using it for demonstration
provides us the opportunity of comparing the result from our own implementation and the provides us the opportunity of comparing the result from our own implementation and the
one from XGBoost internal for learning purposes. After finishing this tutorial, we should one from XGBoost internal for learning purposes. After finishing this tutorial, we should
be able to provide our own functions for rapid experiments. be able to provide our own functions for rapid experiments. And at the end, we will
provide some notes on the non-identity link function along with examples of using custom
metric and objective with the `scikit-learn` interface.
***************************** *****************************
Customized Objective Function Customized Objective Function
@ -125,12 +134,12 @@ We will be able to see XGBoost printing something like:
.. code-block:: none .. code-block:: none
[0] dtrain-PyRMSLE:1.37153 dtest-PyRMSLE:1.31487 [0] dtrain-PyRMSLE:1.37153 dtest-PyRMSLE:1.31487
[1] dtrain-PyRMSLE:1.26619 dtest-PyRMSLE:1.20899 [1] dtrain-PyRMSLE:1.26619 dtest-PyRMSLE:1.20899
[2] dtrain-PyRMSLE:1.17508 dtest-PyRMSLE:1.11629 [2] dtrain-PyRMSLE:1.17508 dtest-PyRMSLE:1.11629
[3] dtrain-PyRMSLE:1.09836 dtest-PyRMSLE:1.03871 [3] dtrain-PyRMSLE:1.09836 dtest-PyRMSLE:1.03871
[4] dtrain-PyRMSLE:1.03557 dtest-PyRMSLE:0.977186 [4] dtrain-PyRMSLE:1.03557 dtest-PyRMSLE:0.977186
[5] dtrain-PyRMSLE:0.985783 dtest-PyRMSLE:0.93057 [5] dtrain-PyRMSLE:0.985783 dtest-PyRMSLE:0.93057
... ...
Notice that the parameter ``disable_default_eval_metric`` is used to suppress the default metric Notice that the parameter ``disable_default_eval_metric`` is used to suppress the default metric
@ -138,11 +147,164 @@ in XGBoost.
For fully reproducible source code and comparison plots, see `custom_rmsle.py <https://github.com/dmlc/xgboost/tree/master/demo/guide-python/custom_rmsle.py>`_. For fully reproducible source code and comparison plots, see `custom_rmsle.py <https://github.com/dmlc/xgboost/tree/master/demo/guide-python/custom_rmsle.py>`_.
*********************
Reverse Link Function
*********************
****************************** When using builtin objective, the raw prediction is transformed according to the objective
Multi-class objective function function. When a custom objective is provided, XGBoost doesn't know its link function, so the
****************************** user is responsible for making the transformation for both the objective and custom evaluation
metric. For objectives with identity link like ``squared error`` this is trivial, but for
other link functions like log link or inverse link the difference is significant.
A similar demo for multi-class objective function is also available, see For the Python package, the behaviour of prediction can be controlled by the
`demo/guide-python/custom_softmax.py <https://github.com/dmlc/xgboost/tree/master/demo/guide-python/custom_softmax.py>`_ ``output_margin`` parameter in the ``predict`` function. When using the ``custom_metric``
for details. parameter without a custom objective, the metric function will receive the transformed
prediction since the objective is defined by XGBoost. However, when a custom objective is
also provided along with that metric, then both the objective and the custom metric will
receive the raw prediction. The following example provides a comparison between the two
behaviours with a multi-class classification model. First we define two different Python
metric functions implementing the same underlying metric for comparison:
`merror_with_transform` is used when a custom objective is also used, otherwise the simpler
`merror` is preferred since XGBoost can perform the transformation itself.
.. code-block:: python
import xgboost as xgb
import numpy as np
def merror_with_transform(predt: np.ndarray, dtrain: xgb.DMatrix):
    """Used when custom objective is supplied."""
    y = dtrain.get_label()
    n_classes = predt.size // y.shape[0]
    # Like custom objective, the predt is untransformed leaf weight when custom
    # objective is provided.

    # With the use of `custom_metric` parameter in train function, custom metric
    # receives raw input only when custom objective is also being used.  Otherwise
    # custom metric will receive transformed prediction.
    assert predt.shape == (dtrain.num_row(), n_classes)
    out = np.zeros(dtrain.num_row())
    for r in range(predt.shape[0]):
        i = np.argmax(predt[r])
        out[r] = i

    assert y.shape == out.shape

    errors = np.zeros(dtrain.num_row())
    errors[y != out] = 1.0
    return 'PyMError', np.sum(errors) / dtrain.num_row()
The above function is only needed when we want to use a custom objective, since XGBoost
doesn't know how to transform the prediction in that case. The normal implementation for a
multi-class error function is:
.. code-block:: python
def merror(predt: np.ndarray, dtrain: xgb.DMatrix):
    """Used when there's no custom objective."""
    y = dtrain.get_label()
    # No need to do transform, XGBoost handles it internally.  With the
    # `multi:softmax` objective used below, `predt` is already the predicted class.
    errors = np.zeros(dtrain.num_row())
    errors[y != predt] = 1.0
    return 'PyMError', np.sum(errors) / dtrain.num_row()
Next we need the custom softprob objective:
.. code-block:: python
def softprob_obj(predt: np.ndarray, data: xgb.DMatrix):
    """Loss function.  Computing the gradient and approximated hessian (diagonal).
    Reimplements the `multi:softprob` inside XGBoost.
    """
    # Full implementation is available in the Python demo script linked below
    ...
    return grad, hess
Lastly we can train the model using ``obj`` and ``custom_metric`` parameters:
.. code-block:: python
Xy = xgb.DMatrix(X, y)
booster = xgb.train(
    {"num_class": kClasses, "disable_default_eval_metric": True},
    Xy,
    num_boost_round=kRounds,
    obj=softprob_obj,
    custom_metric=merror_with_transform,
    evals_result=custom_results,
    evals=[(Xy, "train")],
)
Or if you don't need the custom objective and just want to supply a metric that's not
available in XGBoost:
.. code-block:: python
booster = xgb.train(
    {
        "num_class": kClasses,
        "disable_default_eval_metric": True,
        "objective": "multi:softmax",
    },
    Xy,
    num_boost_round=kRounds,
    # Use a simpler metric implementation.
    custom_metric=merror,
    evals_result=custom_results,
    evals=[(Xy, "train")],
)
We use ``multi:softmax`` to illustrate the difference in transformed prediction. With
``softprob`` the output prediction array has shape ``(n_samples, n_classes)``, while for
``softmax`` it's ``(n_samples, )``. A demo for the multi-class objective function is also
available at `demo/guide-python/custom_softmax.py
<https://github.com/dmlc/xgboost/tree/master/demo/guide-python/custom_softmax.py>`_.
**********************
Scikit-Learn Interface
**********************
The scikit-learn interface of XGBoost has some utilities to improve the integration with
standard scikit-learn functions. For instance, after XGBoost 1.5.1 users can use cost
functions (not scoring functions) from scikit-learn out of the box:
.. code-block:: python
from sklearn.datasets import load_diabetes
from sklearn.metrics import mean_absolute_error
X, y = load_diabetes(return_X_y=True)
reg = xgb.XGBRegressor(
tree_method="hist",
eval_metric=mean_absolute_error,
)
reg.fit(X, y, eval_set=[(X, y)])
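
Since this commit also moves ``early_stopping_rounds`` into the estimator constructor, the
same estimator can carry its early stopping configuration with the model (a small sketch
extending the snippet above):

.. code-block:: python

    reg = xgb.XGBRegressor(
        tree_method="hist",
        eval_metric=mean_absolute_error,
        early_stopping_rounds=3,
    )
    reg.fit(X, y, eval_set=[(X, y)])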
Also, for custom objective function, users can define the objective without having to
access ``DMatrix``:
.. code-block:: python
from typing import Tuple

import numpy as np
# `softmax` is assumed to come from scipy here; any numerically stable
# implementation works.
from scipy.special import softmax


def softprob_obj(labels: np.ndarray, predt: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    rows = labels.shape[0]
    classes = predt.shape[1]
    grad = np.zeros((rows, classes), dtype=float)
    hess = np.zeros((rows, classes), dtype=float)
    eps = 1e-6
    for r in range(predt.shape[0]):
        target = labels[r]
        p = softmax(predt[r, :])
        for c in range(predt.shape[1]):
            g = p[c] - 1.0 if c == target else p[c]
            h = max((2.0 * p[c] * (1.0 - p[c])).item(), eps)
            grad[r, c] = g
            hess[r, c] = h

    grad = grad.reshape((rows * classes, 1))
    hess = hess.reshape((rows * classes, 1))
    return grad, hess


clf = xgb.XGBClassifier(tree_method="hist", objective=softprob_obj)
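
A hedged usage sketch with synthetic data (``clf`` as defined above; the wrapper infers the
number of classes from the labels):

.. code-block:: python

    X = np.random.randn(100, 5)
    y = np.random.randint(0, 3, size=100)
    clf.fit(X, y)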

View File

@ -103,10 +103,13 @@ class CallbackContainer:
EvalsLog = TrainingCallback.EvalsLog EvalsLog = TrainingCallback.EvalsLog
def __init__(self, def __init__(
callbacks: List[TrainingCallback], self,
metric: Callable = None, callbacks: List[TrainingCallback],
is_cv: bool = False): metric: Callable = None,
output_margin: bool = True,
is_cv: bool = False
) -> None:
self.callbacks = set(callbacks) self.callbacks = set(callbacks)
if metric is not None: if metric is not None:
msg = 'metric must be callable object for monitoring. For ' + \ msg = 'metric must be callable object for monitoring. For ' + \
@ -115,6 +118,7 @@ class CallbackContainer:
assert callable(metric), msg assert callable(metric), msg
self.metric = metric self.metric = metric
self.history: TrainingCallback.EvalsLog = collections.OrderedDict() self.history: TrainingCallback.EvalsLog = collections.OrderedDict()
self._output_margin = output_margin
self.is_cv = is_cv self.is_cv = is_cv
if self.is_cv: if self.is_cv:
@ -171,7 +175,7 @@ class CallbackContainer:
def after_iteration(self, model, epoch, dtrain, evals) -> bool: def after_iteration(self, model, epoch, dtrain, evals) -> bool:
'''Function called after training iteration.''' '''Function called after training iteration.'''
if self.is_cv: if self.is_cv:
scores = model.eval(epoch, self.metric) scores = model.eval(epoch, self.metric, self._output_margin)
scores = _aggcv(scores) scores = _aggcv(scores)
self.aggregated_cv = scores self.aggregated_cv = scores
self._update_history(scores, epoch) self._update_history(scores, epoch)
@ -179,7 +183,7 @@ class CallbackContainer:
evals = [] if evals is None else evals evals = [] if evals is None else evals
for _, name in evals: for _, name in evals:
assert name.find('-') == -1, 'Dataset name should not contain `-`' assert name.find('-') == -1, 'Dataset name should not contain `-`'
score = model.eval_set(evals, epoch, self.metric) score = model.eval_set(evals, epoch, self.metric, self._output_margin)
score = score.split()[1:] # into datasets score = score.split()[1:] # into datasets
# split up `test-error:0.1234` # split up `test-error:0.1234`
score = [tuple(s.split(':')) for s in score] score = [tuple(s.split(':')) for s in score]

View File

@ -1700,7 +1700,7 @@ class Booster(object):
c_array(ctypes.c_float, hess), c_array(ctypes.c_float, hess),
c_bst_ulong(len(grad)))) c_bst_ulong(len(grad))))
def eval_set(self, evals, iteration=0, feval=None): def eval_set(self, evals, iteration=0, feval=None, output_margin=True):
# pylint: disable=invalid-name # pylint: disable=invalid-name
"""Evaluate a set of data. """Evaluate a set of data.
@ -1728,24 +1728,30 @@ class Booster(object):
dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals]) dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals])
evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals]) evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals])
msg = ctypes.c_char_p() msg = ctypes.c_char_p()
_check_call(_LIB.XGBoosterEvalOneIter(self.handle, _check_call(
ctypes.c_int(iteration), _LIB.XGBoosterEvalOneIter(
dmats, evnames, self.handle,
c_bst_ulong(len(evals)), ctypes.c_int(iteration),
ctypes.byref(msg))) dmats,
evnames,
c_bst_ulong(len(evals)),
ctypes.byref(msg),
)
)
res = msg.value.decode() # pylint: disable=no-member res = msg.value.decode() # pylint: disable=no-member
if feval is not None: if feval is not None:
for dmat, evname in evals: for dmat, evname in evals:
feval_ret = feval(self.predict(dmat, training=False, feval_ret = feval(
output_margin=True), dmat) self.predict(dmat, training=False, output_margin=output_margin), dmat
)
if isinstance(feval_ret, list): if isinstance(feval_ret, list):
for name, val in feval_ret: for name, val in feval_ret:
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
res += '\t%s-%s:%f' % (evname, name, val) res += "\t%s-%s:%f" % (evname, name, val)
else: else:
name, val = feval_ret name, val = feval_ret
# pylint: disable=consider-using-f-string # pylint: disable=consider-using-f-string
res += '\t%s-%s:%f' % (evname, name, val) res += "\t%s-%s:%f" % (evname, name, val)
return res return res
def eval(self, data, name='eval', iteration=0): def eval(self, data, name='eval', iteration=0):

View File

@ -844,6 +844,7 @@ async def _train_async(
verbose_eval: Union[int, bool], verbose_eval: Union[int, bool],
xgb_model: Optional[Booster], xgb_model: Optional[Booster],
callbacks: Optional[List[TrainingCallback]], callbacks: Optional[List[TrainingCallback]],
custom_metric: Optional[Metric],
) -> Optional[TrainReturnT]: ) -> Optional[TrainReturnT]:
workers = _get_workers_from_data(dtrain, evals) workers = _get_workers_from_data(dtrain, evals)
_rabit_args = await _get_rabit_args(len(workers), client) _rabit_args = await _get_rabit_args(len(workers), client)
@ -896,6 +897,7 @@ async def _train_async(
evals=local_evals, evals=local_evals,
obj=obj, obj=obj,
feval=feval, feval=feval,
custom_metric=custom_metric,
early_stopping_rounds=early_stopping_rounds, early_stopping_rounds=early_stopping_rounds,
verbose_eval=verbose_eval, verbose_eval=verbose_eval,
xgb_model=xgb_model, xgb_model=xgb_model,
@ -942,11 +944,13 @@ async def _train_async(
return list(filter(lambda ret: ret is not None, results))[0] return list(filter(lambda ret: ret is not None, results))[0]
@_deprecate_positional_args
def train( # pylint: disable=unused-argument def train( # pylint: disable=unused-argument
client: "distributed.Client", client: "distributed.Client",
params: Dict[str, Any], params: Dict[str, Any],
dtrain: DaskDMatrix, dtrain: DaskDMatrix,
num_boost_round: int = 10, num_boost_round: int = 10,
*,
evals: Optional[List[Tuple[DaskDMatrix, str]]] = None, evals: Optional[List[Tuple[DaskDMatrix, str]]] = None,
obj: Optional[Objective] = None, obj: Optional[Objective] = None,
feval: Optional[Metric] = None, feval: Optional[Metric] = None,
@ -954,6 +958,7 @@ def train( # pylint: disable=unused-argument
xgb_model: Optional[Booster] = None, xgb_model: Optional[Booster] = None,
verbose_eval: Union[int, bool] = True, verbose_eval: Union[int, bool] = True,
callbacks: Optional[List[TrainingCallback]] = None, callbacks: Optional[List[TrainingCallback]] = None,
custom_metric: Optional[Metric] = None,
) -> Any: ) -> Any:
"""Train XGBoost model. """Train XGBoost model.
@ -1647,7 +1652,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
eval_metric: Optional[Union[str, List[str], Metric]], eval_metric: Optional[Union[str, List[str], Metric]],
sample_weight_eval_set: Optional[List[_DaskCollection]], sample_weight_eval_set: Optional[List[_DaskCollection]],
base_margin_eval_set: Optional[List[_DaskCollection]], base_margin_eval_set: Optional[List[_DaskCollection]],
early_stopping_rounds: int, early_stopping_rounds: Optional[int],
verbose: bool, verbose: bool,
xgb_model: Optional[Union[Booster, XGBModel]], xgb_model: Optional[Union[Booster, XGBModel]],
feature_weights: Optional[_DaskCollection], feature_weights: Optional[_DaskCollection],
@ -1676,8 +1681,8 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
obj: Optional[Callable] = _objective_decorator(self.objective) obj: Optional[Callable] = _objective_decorator(self.objective)
else: else:
obj = None obj = None
model, metric, params = self._configure_fit( model, metric, params, early_stopping_rounds = self._configure_fit(
booster=xgb_model, eval_metric=eval_metric, params=params xgb_model, eval_metric, params, early_stopping_rounds
) )
results = await self.client.sync( results = await self.client.sync(
_train_async, _train_async,
@ -1689,7 +1694,8 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
num_boost_round=self.get_num_boosting_rounds(), num_boost_round=self.get_num_boosting_rounds(),
evals=evals, evals=evals,
obj=obj, obj=obj,
feval=metric, feval=None,
custom_metric=metric,
verbose_eval=verbose, verbose_eval=verbose,
early_stopping_rounds=early_stopping_rounds, early_stopping_rounds=early_stopping_rounds,
callbacks=callbacks, callbacks=callbacks,
@ -1736,7 +1742,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
eval_metric: Optional[Union[str, List[str], Metric]], eval_metric: Optional[Union[str, List[str], Metric]],
sample_weight_eval_set: Optional[List[_DaskCollection]], sample_weight_eval_set: Optional[List[_DaskCollection]],
base_margin_eval_set: Optional[List[_DaskCollection]], base_margin_eval_set: Optional[List[_DaskCollection]],
early_stopping_rounds: int, early_stopping_rounds: Optional[int],
verbose: bool, verbose: bool,
xgb_model: Optional[Union[Booster, XGBModel]], xgb_model: Optional[Union[Booster, XGBModel]],
feature_weights: Optional[_DaskCollection], feature_weights: Optional[_DaskCollection],
@ -1778,8 +1784,8 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
obj: Optional[Callable] = _objective_decorator(self.objective) obj: Optional[Callable] = _objective_decorator(self.objective)
else: else:
obj = None obj = None
model, metric, params = self._configure_fit( model, metric, params, early_stopping_rounds = self._configure_fit(
booster=xgb_model, eval_metric=eval_metric, params=params xgb_model, eval_metric, params, early_stopping_rounds
) )
results = await self.client.sync( results = await self.client.sync(
_train_async, _train_async,
@ -1791,7 +1797,8 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
num_boost_round=self.get_num_boosting_rounds(), num_boost_round=self.get_num_boosting_rounds(),
evals=evals, evals=evals,
obj=obj, obj=obj,
feval=metric, feval=None,
custom_metric=metric,
verbose_eval=verbose, verbose_eval=verbose,
early_stopping_rounds=early_stopping_rounds, early_stopping_rounds=early_stopping_rounds,
callbacks=callbacks, callbacks=callbacks,
@ -1832,9 +1839,14 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
base_margin: Optional[_DaskCollection], base_margin: Optional[_DaskCollection],
iteration_range: Optional[Tuple[int, int]], iteration_range: Optional[Tuple[int, int]],
) -> _DaskCollection: ) -> _DaskCollection:
if self.objective == "multi:softmax":
raise ValueError(
"multi:softmax doesn't support `predict_proba`. "
"Switch to `multi:softproba` instead"
)
predts = await super()._predict_async( predts = await super()._predict_async(
data=X, data=X,
output_margin=self.objective == "multi:softmax", output_margin=False,
validate_features=validate_features, validate_features=validate_features,
base_margin=base_margin, base_margin=base_margin,
iteration_range=iteration_range, iteration_range=iteration_range,
@ -1903,9 +1915,9 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
""", """,
["estimators", "model"], ["estimators", "model"],
end_note=""" end_note="""
Note .. note::
----
For dask implementation, group is not supported, use qid instead. For dask implementation, group is not supported, use qid instead.
""", """,
) )
class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn): class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
@ -1929,7 +1941,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
eval_group: Optional[List[_DaskCollection]], eval_group: Optional[List[_DaskCollection]],
eval_qid: Optional[List[_DaskCollection]], eval_qid: Optional[List[_DaskCollection]],
eval_metric: Optional[Union[str, List[str], Metric]], eval_metric: Optional[Union[str, List[str], Metric]],
early_stopping_rounds: int, early_stopping_rounds: Optional[int],
verbose: bool, verbose: bool,
xgb_model: Optional[Union[XGBModel, Booster]], xgb_model: Optional[Union[XGBModel, Booster]],
feature_weights: Optional[_DaskCollection], feature_weights: Optional[_DaskCollection],
@ -1963,8 +1975,8 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
raise ValueError( raise ValueError(
"Custom evaluation metric is not yet supported for XGBRanker." "Custom evaluation metric is not yet supported for XGBRanker."
) )
model, metric, params = self._configure_fit( model, metric, params, early_stopping_rounds = self._configure_fit(
booster=xgb_model, eval_metric=eval_metric, params=params xgb_model, eval_metric, params, early_stopping_rounds
) )
results = await self.client.sync( results = await self.client.sync(
_train_async, _train_async,
@ -1976,7 +1988,8 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
num_boost_round=self.get_num_boosting_rounds(), num_boost_round=self.get_num_boosting_rounds(),
evals=evals, evals=evals,
obj=None, obj=None,
feval=metric, feval=None,
custom_metric=metric,
verbose_eval=verbose, verbose_eval=verbose,
early_stopping_rounds=early_stopping_rounds, early_stopping_rounds=early_stopping_rounds,
callbacks=callbacks, callbacks=callbacks,

View File

@ -89,6 +89,19 @@ def _objective_decorator(
return inner return inner
def _metric_decorator(func: Callable) -> Metric:
"""Decorate a metric function from sklearn.
Converts a metric function that uses the typical sklearn metric signature so that it
is compatible with :py:func:`train`.
"""
def inner(y_score: np.ndarray, dmatrix: DMatrix) -> Tuple[str, float]:
y_true = dmatrix.get_label()
return func.__name__, func(y_true, y_score)
return inner
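    # A hedged usage sketch: wrapping an sklearn metric yields the (name, value)
    # pair that :py:func:`train` expects, e.g.
    #
    #     metric = _metric_decorator(sklearn.metrics.mean_absolute_error)
    #     metric(y_score, dmatrix)  # -> ("mean_absolute_error", <float>)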
__estimator_doc = ''' __estimator_doc = '''
n_estimators : int n_estimators : int
Number of gradient boosted trees. Equivalent to number of boosting Number of gradient boosted trees. Equivalent to number of boosting
@ -183,6 +196,66 @@ __model_doc = f'''
Experimental support for categorical data. Do not set to true unless you are Experimental support for categorical data. Do not set to true unless you are
interested in development. Only valid when `gpu_hist` and dataframe are used. interested in development. Only valid when `gpu_hist` and dataframe are used.
eval_metric : Optional[Union[str, List[str], Callable]]
.. versionadded:: 1.5.1
Metric used for monitoring the training result and early stopping. It can be a
string or list of strings as names of predefined metrics in XGBoost (See
doc/parameter.rst), one of the metrics in :py:mod:`sklearn.metrics`, or any other
user defined metric that looks like `sklearn.metrics`.
If custom objective is also provided, then custom metric should implement the
corresponding reverse link function.
Unlike the `scoring` parameter commonly used in scikit-learn, when a callable
object is provided, it's assumed to be a cost function and by default XGBoost will
minimize the result during early stopping.
For advanced usage on early stopping, like directly choosing to maximize instead of
minimize, see :py:obj:`xgboost.callback.EarlyStopping`.
See `Custom Objective and Evaluation Metric
<https://xgboost.readthedocs.io/en/latest/tutorials/custom_metric_obj.html>`_ for
more.
.. note::
This parameter replaces `eval_metric` in :py:meth:`fit` method. The old one
receives un-transformed prediction regardless of whether custom objective is
being used.
.. code-block:: python
from sklearn.datasets import load_diabetes
from sklearn.metrics import mean_absolute_error
X, y = load_diabetes(return_X_y=True)
reg = xgb.XGBRegressor(
tree_method="hist",
eval_metric=mean_absolute_error,
)
reg.fit(X, y, eval_set=[(X, y)])
early_stopping_rounds : Optional[int]
.. versionadded:: 1.5.1
Activates early stopping. Validation metric needs to improve at least once in
every **early_stopping_rounds** round(s) to continue training. Requires at least
one item in **eval_set** in :py:meth:`xgboost.sklearn.XGBModel.fit`.
The method returns the model from the last iteration (not the best one). If
there's more than one item in **eval_set**, the last entry will be used for early
stopping. If there's more than one metric in **eval_metric**, the last metric
will be used for early stopping.
If early stopping occurs, the model will have three additional fields:
``clf.best_score``, ``clf.best_iteration`` and ``clf.best_ntree_limit``.
.. note::
This parameter replaces `early_stopping_rounds` in :py:meth:`fit` method.
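
A minimal sketch (``X_train``, ``y_train``, ``X_valid`` and ``y_valid`` are assumed
to be defined by the caller):

.. code-block:: python

    clf = xgb.XGBClassifier(eval_metric="logloss", early_stopping_rounds=5)
    clf.fit(X_train, y_train, eval_set=[(X_valid, y_valid)])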
kwargs : dict, optional kwargs : dict, optional
Keyword arguments for XGBoost Booster object. Full documentation of Keyword arguments for XGBoost Booster object. Full documentation of
parameters can be found here: parameters can be found here:
@ -397,6 +470,8 @@ class XGBModel(XGBModelBase):
validate_parameters: Optional[bool] = None, validate_parameters: Optional[bool] = None,
predictor: Optional[str] = None, predictor: Optional[str] = None,
enable_categorical: bool = False, enable_categorical: bool = False,
eval_metric: Optional[Union[str, List[str], Callable]] = None,
early_stopping_rounds: Optional[int] = None,
**kwargs: Any **kwargs: Any
) -> None: ) -> None:
if not SKLEARN_INSTALLED: if not SKLEARN_INSTALLED:
@ -433,6 +508,8 @@ class XGBModel(XGBModelBase):
self.validate_parameters = validate_parameters self.validate_parameters = validate_parameters
self.predictor = predictor self.predictor = predictor
self.enable_categorical = enable_categorical self.enable_categorical = enable_categorical
self.eval_metric = eval_metric
self.early_stopping_rounds = early_stopping_rounds
if kwargs: if kwargs:
self.kwargs = kwargs self.kwargs = kwargs
@ -543,8 +620,13 @@ class XGBModel(XGBModelBase):
params = self.get_params() params = self.get_params()
# Parameters that should not go into native learner. # Parameters that should not go into native learner.
wrapper_specific = { wrapper_specific = {
'importance_type', 'kwargs', 'missing', 'n_estimators', 'use_label_encoder', "importance_type",
"enable_categorical" "kwargs",
"missing",
"n_estimators",
"use_label_encoder",
"enable_categorical",
"early_stopping_rounds",
} }
filtered = {} filtered = {}
for k, v in params.items(): for k, v in params.items():
@ -629,32 +711,80 @@ class XGBModel(XGBModelBase):
load_model.__doc__ = f"""{Booster.load_model.__doc__}""" load_model.__doc__ = f"""{Booster.load_model.__doc__}"""
# pylint: disable=too-many-branches
def _configure_fit( def _configure_fit(
self, self,
booster: Optional[Union[Booster, "XGBModel", str]], booster: Optional[Union[Booster, "XGBModel", str]],
eval_metric: Optional[Union[Callable, str, List[str]]], eval_metric: Optional[Union[Callable, str, List[str]]],
params: Dict[str, Any], params: Dict[str, Any],
) -> Tuple[Optional[Union[Booster, str]], Optional[Metric], Dict[str, Any]]: early_stopping_rounds: Optional[int],
# pylint: disable=protected-access, no-self-use ) -> Tuple[
Optional[Union[Booster, str, "XGBModel"]],
Optional[Metric],
Dict[str, Any],
Optional[int],
]:
"""Configure parameters for :py:meth:`fit`."""
if isinstance(booster, XGBModel): if isinstance(booster, XGBModel):
# Handle the case when xgb_model is a sklearn model object model: Optional[Union[Booster, str]] = booster.get_booster()
model: Optional[Union[Booster, str]] = booster._Booster
else: else:
model = booster model = booster
feval = eval_metric if callable(eval_metric) else None def _deprecated(parameter: str) -> None:
warnings.warn(
f"`{parameter}` in `fit` method is deprecated for better compatibility "
f"with scikit-learn, use `{parameter}` in constructor or`set_params` "
"instead.",
UserWarning,
)
def _duplicated(parameter: str) -> None:
raise ValueError(
f"2 different `{parameter}` are provided. Use the one in constructor "
"or `set_params` instead."
)
# Configure evaluation metric.
if eval_metric is not None: if eval_metric is not None:
if callable(eval_metric): _deprecated("eval_metric")
eval_metric = None if self.eval_metric is not None and eval_metric is not None:
_duplicated("eval_metric")
# - track where the evaluation metric comes from
if self.eval_metric is not None:
from_fit = False
eval_metric = self.eval_metric
else:
from_fit = True
# - configure callable evaluation metric
metric: Optional[Metric] = None
if eval_metric is not None:
if callable(eval_metric) and from_fit:
# No need to wrap the evaluation function for old parameter.
metric = eval_metric
elif callable(eval_metric):
# Parameter from constructor or set_params
metric = _metric_decorator(eval_metric)
else: else:
params.update({"eval_metric": eval_metric}) params.update({"eval_metric": eval_metric})
# Configure early_stopping_rounds
if early_stopping_rounds is not None:
_deprecated("early_stopping_rounds")
if early_stopping_rounds is not None and self.early_stopping_rounds is not None:
_duplicated("early_stopping_rounds")
early_stopping_rounds = (
self.early_stopping_rounds
if self.early_stopping_rounds is not None
else early_stopping_rounds
)
if self.enable_categorical and params.get("tree_method", None) != "gpu_hist": if self.enable_categorical and params.get("tree_method", None) != "gpu_hist":
raise ValueError( raise ValueError(
"Experimental support for categorical data is not implemented for" "Experimental support for categorical data is not implemented for"
" current tree method yet." " current tree method yet."
) )
return model, feval, params return model, metric, params, early_stopping_rounds
def _set_evaluation_result(self, evals_result: TrainingCallback.EvalsLog) -> None: def _set_evaluation_result(self, evals_result: TrainingCallback.EvalsLog) -> None:
if evals_result: if evals_result:
@ -702,31 +832,15 @@ class XGBModel(XGBModelBase):
A list of (X, y) tuple pairs to use as validation sets, for which A list of (X, y) tuple pairs to use as validation sets, for which
metrics will be computed. metrics will be computed.
Validation metrics will help us track the performance of the model. Validation metrics will help us track the performance of the model.
eval_metric :
If a str, should be a built-in evaluation metric to use. See doc/parameter.rst.
If a list of str, should be the list of multiple built-in evaluation metrics eval_metric : str, list of str, or callable, optional
to use. .. deprecated:: 1.5.1
Use `eval_metric` in :py:meth:`__init__` or :py:meth:`set_params` instead.
If callable, a custom evaluation metric. The call signature is early_stopping_rounds : int
``func(y_predicted, y_true)`` where ``y_true`` will be a DMatrix object such .. deprecated:: 1.5.1
that you may need to call the ``get_label`` method. It must return a str, Use `early_stopping_rounds` in :py:meth:`__init__` or
value pair where the str is a name for the evaluation and value is the value :py:meth:`set_params` instead.
of the evaluation function. The callable custom objective is always minimized.
early_stopping_rounds :
Activates early stopping. Validation metric needs to improve at least once in
every **early_stopping_rounds** round(s) to continue training.
Requires at least one item in **eval_set**.
The method returns the model from the last iteration (not the best one).
If there's more than one item in **eval_set**, the last entry will be used
for early stopping.
If there's more than one metric in **eval_metric**, the last metric will be
used for early stopping.
If early stopping occurs, the model will have three additional fields:
``clf.best_score``, ``clf.best_iteration``.
verbose : verbose :
If `verbose` and an evaluation set is used, writes the evaluation metric If `verbose` and an evaluation set is used, writes the evaluation metric
measured on the validation set to stderr. measured on the validation set to stderr.
@ -783,7 +897,9 @@ class XGBModel(XGBModelBase):
else: else:
obj = None obj = None
model, feval, params = self._configure_fit(xgb_model, eval_metric, params) model, metric, params, early_stopping_rounds = self._configure_fit(
xgb_model, eval_metric, params, early_stopping_rounds
)
self._Booster = train( self._Booster = train(
params, params,
train_dmatrix, train_dmatrix,
@ -792,7 +908,7 @@ class XGBModel(XGBModelBase):
early_stopping_rounds=early_stopping_rounds, early_stopping_rounds=early_stopping_rounds,
evals_result=evals_result, evals_result=evals_result,
obj=obj, obj=obj,
feval=feval, custom_metric=metric,
verbose_eval=verbose, verbose_eval=verbose,
xgb_model=model, xgb_model=model,
callbacks=callbacks, callbacks=callbacks,
@ -1185,12 +1301,14 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
obj = None obj = None
if self.n_classes_ > 2: if self.n_classes_ > 2:
# Switch to using a multiclass objective in the underlying # Switch to using a multiclass objective in the underlying XGB instance
# XGB instance if params.get("objective", None) != "multi:softmax":
params["objective"] = "multi:softprob" params["objective"] = "multi:softprob"
params["num_class"] = self.n_classes_ params["num_class"] = self.n_classes_
model, feval, params = self._configure_fit(xgb_model, eval_metric, params) model, metric, params, early_stopping_rounds = self._configure_fit(
xgb_model, eval_metric, params, early_stopping_rounds
)
train_dmatrix, evals = _wrap_evaluation_matrices( train_dmatrix, evals = _wrap_evaluation_matrices(
missing=self.missing, missing=self.missing,
X=X, X=X,
@ -1217,7 +1335,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
early_stopping_rounds=early_stopping_rounds, early_stopping_rounds=early_stopping_rounds,
evals_result=evals_result, evals_result=evals_result,
obj=obj, obj=obj,
feval=feval, custom_metric=metric,
verbose_eval=verbose, verbose_eval=verbose,
xgb_model=model, xgb_model=model,
callbacks=callbacks, callbacks=callbacks,
@ -1304,12 +1422,19 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
""" """
# custom obj: Do nothing as we don't know what to do. # custom obj: Do nothing as we don't know what to do.
# softprob: Do nothing, output is proba. # softprob: Do nothing, output is proba.
# softmax: Use output margin to remove the argmax in PredTransform. # softmax: Unsupported by predict_proba()
# binary:logistic: Expand the prob vector into 2-class matrix after predict. # binary:logistic: Expand the prob vector into 2-class matrix after predict.
# binary:logitraw: Unsupported by predict_proba() # binary:logitraw: Unsupported by predict_proba()
if self.objective == "multi:softmax":
# We need to run a Python implementation of softmax for it. Just ask user to
# use softprob since XGBoost's implementation has mitigation for floating
# point overflow. No need to reinvent the wheel.
raise ValueError(
"multi:softmax doesn't support `predict_proba`. "
"Switch to `multi:softproba` instead"
)
class_probs = super().predict( class_probs = super().predict(
X=X, X=X,
output_margin=self.objective == "multi:softmax",
ntree_limit=ntree_limit, ntree_limit=ntree_limit,
validate_features=validate_features, validate_features=validate_features,
base_margin=base_margin, base_margin=base_margin,
@ -1325,8 +1450,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
If **eval_set** is passed to the `fit` function, you can call If **eval_set** is passed to the `fit` function, you can call
``evals_result()`` to get evaluation results for all passed **eval_sets**. ``evals_result()`` to get evaluation results for all passed **eval_sets**.
When **eval_metric** is also passed to the `fit` function, the
**evals_result** will contain the **eval_metrics** passed to the `fit` function. When **eval_metric** is also passed as a parameter, the **evals_result** will
contain the **eval_metric** passed to the `fit` function.
Returns Returns
------- -------
@ -1337,13 +1463,14 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
.. code-block:: python .. code-block:: python
param_dist = {'objective':'binary:logistic', 'n_estimators':2} param_dist = {
'objective': 'binary:logistic', 'n_estimators': 2, 'eval_metric': 'logloss'
}
clf = xgb.XGBClassifier(**param_dist) clf = xgb.XGBClassifier(**param_dist)
clf.fit(X_train, y_train, clf.fit(X_train, y_train,
eval_set=[(X_train, y_train), (X_test, y_test)], eval_set=[(X_train, y_train), (X_test, y_test)],
eval_metric='logloss',
verbose=True) verbose=True)
evals_result = clf.evals_result() evals_result = clf.evals_result()
@ -1354,6 +1481,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
{'validation_0': {'logloss': ['0.604835', '0.531479']}, {'validation_0': {'logloss': ['0.604835', '0.531479']},
'validation_1': {'logloss': ['0.41965', '0.17686']}} 'validation_1': {'logloss': ['0.41965', '0.17686']}}
""" """
if self.evals_result_: if self.evals_result_:
evals_result = self.evals_result_ evals_result = self.evals_result_
@ -1386,6 +1514,7 @@ class XGBRFClassifier(XGBClassifier):
colsample_bynode=colsample_bynode, colsample_bynode=colsample_bynode,
reg_lambda=reg_lambda, reg_lambda=reg_lambda,
**kwargs) **kwargs)
_check_rf_callback(self.early_stopping_rounds, None)
def get_xgb_params(self) -> Dict[str, Any]: def get_xgb_params(self) -> Dict[str, Any]:
params = super().get_xgb_params() params = super().get_xgb_params()
@ -1457,6 +1586,7 @@ class XGBRFRegressor(XGBRegressor):
reg_lambda=reg_lambda, reg_lambda=reg_lambda,
**kwargs **kwargs
) )
_check_rf_callback(self.early_stopping_rounds, None)
def get_xgb_params(self) -> Dict[str, Any]: def get_xgb_params(self) -> Dict[str, Any]:
params = super().get_xgb_params() params = super().get_xgb_params()
@ -1495,15 +1625,15 @@ class XGBRFRegressor(XGBRegressor):
'Implementation of the Scikit-Learn API for XGBoost Ranking.', 'Implementation of the Scikit-Learn API for XGBoost Ranking.',
['estimators', 'model'], ['estimators', 'model'],
end_note=''' end_note='''
Note .. note::
----
A custom objective function is currently not supported by XGBRanker.
Likewise, a custom metric function is not supported either.
Note A custom objective function is currently not supported by XGBRanker.
---- Likewise, a custom metric function is not supported either.
Query group information is required for ranking tasks by either using the `group`
parameter or `qid` parameter in `fit` method. .. note::
Query group information is required for ranking tasks by either using the
`group` parameter or `qid` parameter in `fit` method.
Before fitting the model, your data need to be sorted by query group. When fitting Before fitting the model, your data need to be sorted by query group. When fitting
the model, you need to provide an additional array that contains the size of each the model, you need to provide an additional array that contains the size of each
@ -1605,22 +1735,16 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
eval_qid : eval_qid :
A list in which ``eval_qid[i]`` is the array containing query ID of ``i``-th A list in which ``eval_qid[i]`` is the array containing query ID of ``i``-th
pair in **eval_set**. pair in **eval_set**.
eval_metric :
If a str, should be a built-in evaluation metric to use. See eval_metric : str, list of str, optional
doc/parameter.rst. .. deprecated:: 1.5.1
If a list of str, should be the list of multiple built-in evaluation metrics use `eval_metric` in :py:meth:`__init__` or :py:meth:`set_params` instead.
to use. The custom evaluation metric is not yet supported for the ranker.
early_stopping_rounds : early_stopping_rounds : int
Activates early stopping. Validation metric needs to improve at least once in .. deprecated:: 1.5.1
every **early_stopping_rounds** round(s) to continue training. Requires at use `early_stopping_rounds` in :py:meth:`__init__` or
least one item in **eval_set**. :py:meth:`set_params` instead.
The method returns the model from the last iteration (not the best one). If
there's more than one item in **eval_set**, the last entry will be used for
early stopping.
If there's more than one metric in **eval_metric**, the last metric will be
used for early stopping.
If early stopping occurs, the model will have three additional fields:
``clf.best_score``, ``clf.best_iteration`` and ``clf.best_ntree_limit``.
verbose : verbose :
If `verbose` and an evaluation set is used, writes the evaluation metric If `verbose` and an evaluation set is used, writes the evaluation metric
measured on the validation set to stderr. measured on the validation set to stderr.
@ -1685,8 +1809,10 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
evals_result: TrainingCallback.EvalsLog = {} evals_result: TrainingCallback.EvalsLog = {}
params = self.get_xgb_params() params = self.get_xgb_params()
model, feval, params = self._configure_fit(xgb_model, eval_metric, params) model, metric, params, early_stopping_rounds = self._configure_fit(
if callable(feval): xgb_model, eval_metric, params, early_stopping_rounds
)
if callable(metric):
raise ValueError( raise ValueError(
'Custom evaluation metric is not yet supported for XGBRanker.' 'Custom evaluation metric is not yet supported for XGBRanker.'
) )
@ -1696,7 +1822,8 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
self.n_estimators, self.n_estimators,
early_stopping_rounds=early_stopping_rounds, early_stopping_rounds=early_stopping_rounds,
evals=evals, evals=evals,
evals_result=evals_result, feval=feval, evals_result=evals_result,
custom_metric=metric,
verbose_eval=verbose, xgb_model=model, verbose_eval=verbose, xgb_model=model,
callbacks=callbacks callbacks=callbacks
) )

View File

@ -4,9 +4,12 @@
"""Training Library containing training routines.""" """Training Library containing training routines."""
import copy import copy
from typing import Optional, List from typing import Optional, List
import warnings
import numpy as np import numpy as np
from .core import Booster, XGBoostError, _get_booster_layer_trees from .core import Booster, XGBoostError, _get_booster_layer_trees
from .core import _deprecate_positional_args
from .core import Objective, Metric
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold) from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
from . import callback from . import callback
@ -22,21 +25,48 @@ def _assert_new_callback(callbacks: Optional[List[callback.TrainingCallback]]) -
) )
def _train_internal(params, dtrain, def _configure_custom_metric(
num_boost_round=10, evals=(), feval: Optional[Metric], custom_metric: Optional[Metric]
obj=None, feval=None, ) -> Optional[Metric]:
xgb_model=None, callbacks=None, if feval is not None:
evals_result=None, maximize=None, link = "https://xgboost.readthedocs.io/en/latest/tutorials/custom_metric_obj.html"
verbose_eval=None, early_stopping_rounds=None): warnings.warn(
"`feval` is deprecated, use `custom_metric` instead. They have "
"different behavior when custom objective is also used."
f"See {link} for details on the `custom_metric`."
)
if feval is not None and custom_metric is not None:
raise ValueError(
"Bost `feval` and `custom_metric` are supplied. Use `custom_metric` instead."
)
eval_metric = custom_metric if custom_metric is not None else feval
return eval_metric
def _train_internal(
params,
dtrain,
num_boost_round=10,
evals=(),
obj=None,
feval=None,
custom_metric=None,
xgb_model=None,
callbacks=None,
evals_result=None,
maximize=None,
verbose_eval=None,
early_stopping_rounds=None,
):
"""internal training function""" """internal training function"""
callbacks = [] if callbacks is None else copy.copy(callbacks) callbacks = [] if callbacks is None else copy.copy(callbacks)
metric_fn = _configure_custom_metric(feval, custom_metric)
evals = list(evals) evals = list(evals)
bst = Booster(params, [dtrain] + [d[0] for d in evals]) bst = Booster(params, [dtrain] + [d[0] for d in evals])
if xgb_model is not None: if xgb_model is not None:
bst = Booster(params, [dtrain] + [d[0] for d in evals], bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model)
model_file=xgb_model)
start_iteration = 0 start_iteration = 0
@ -48,7 +78,14 @@ def _train_internal(params, dtrain,
callbacks.append( callbacks.append(
callback.EarlyStopping(rounds=early_stopping_rounds, maximize=maximize) callback.EarlyStopping(rounds=early_stopping_rounds, maximize=maximize)
) )
callbacks = callback.CallbackContainer(callbacks, metric=feval) callbacks = callback.CallbackContainer(
callbacks,
metric=metric_fn,
# For old `feval` parameter, the behavior is unchanged. For the new
# `custom_metric`, it will receive proper prediction result when custom objective
# is not used.
output_margin=callable(obj) or metric_fn is feval,
)
bst = callbacks.before_training(bst) bst = callbacks.before_training(bst)
@ -89,9 +126,23 @@ def _train_internal(params, dtrain,
return bst.copy() return bst.copy()
def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, @_deprecate_positional_args
maximize=None, early_stopping_rounds=None, evals_result=None, def train(
verbose_eval=True, xgb_model=None, callbacks=None): params,
dtrain,
num_boost_round=10,
*,
evals=(),
obj: Optional[Objective] = None,
feval=None,
maximize=None,
early_stopping_rounds=None,
evals_result=None,
verbose_eval=True,
xgb_model=None,
callbacks=None,
custom_metric: Optional[Metric] = None,
):
# pylint: disable=too-many-statements,too-many-branches, attribute-defined-outside-init # pylint: disable=too-many-statements,too-many-branches, attribute-defined-outside-init
"""Train a booster with given parameters. """Train a booster with given parameters.
@ -106,10 +157,13 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
evals: list of pairs (DMatrix, string) evals: list of pairs (DMatrix, string)
List of validation sets for which metrics will evaluated during training. List of validation sets for which metrics will evaluated during training.
Validation metrics will help us track the performance of the model. Validation metrics will help us track the performance of the model.
obj : function obj
Customized objective function. Custom objective function. See `Custom Objective
feval : function <https://xgboost.readthedocs.io/en/latest/tutorials/custom_metric_obj.html>`_ for
Customized evaluation function. details.
feval :
.. deprecated:: 1.5.1
Use `custom_metric` instead.
maximize : bool maximize : bool
Whether to maximize feval. Whether to maximize feval.
early_stopping_rounds: int early_stopping_rounds: int
@ -158,23 +212,37 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
[xgb.callback.LearningRateScheduler(custom_rates)] [xgb.callback.LearningRateScheduler(custom_rates)]
custom_metric:
.. versionadded:: 1.5.1
Custom metric function. See `Custom Metric
<https://xgboost.readthedocs.io/en/latest/tutorials/custom_metric_obj.html>`_ for
details.
Returns Returns
------- -------
Booster : a trained booster model Booster : a trained booster model
""" """
bst = _train_internal(params, dtrain, bst = _train_internal(
num_boost_round=num_boost_round, params,
evals=evals, dtrain,
obj=obj, feval=feval, num_boost_round=num_boost_round,
xgb_model=xgb_model, callbacks=callbacks, evals=evals,
verbose_eval=verbose_eval, obj=obj,
evals_result=evals_result, feval=feval,
maximize=maximize, xgb_model=xgb_model,
early_stopping_rounds=early_stopping_rounds) callbacks=callbacks,
verbose_eval=verbose_eval,
evals_result=evals_result,
maximize=maximize,
early_stopping_rounds=early_stopping_rounds,
custom_metric=custom_metric,
)
return bst return bst
class CVPack(object): class CVPack:
""""Auxiliary datastruct to hold one fold of CV.""" """"Auxiliary datastruct to hold one fold of CV."""
def __init__(self, dtrain, dtest, param): def __init__(self, dtrain, dtest, param):
""""Initialize the CVPack""" """"Initialize the CVPack"""
@ -192,9 +260,9 @@ class CVPack(object):
""""Update the boosters for one iteration""" """"Update the boosters for one iteration"""
self.bst.update(self.dtrain, iteration, fobj) self.bst.update(self.dtrain, iteration, fobj)
def eval(self, iteration, feval): def eval(self, iteration, feval, output_margin):
""""Evaluate the CVPack for one iteration.""" """"Evaluate the CVPack for one iteration."""
return self.bst.eval_set(self.watchlist, iteration, feval) return self.bst.eval_set(self.watchlist, iteration, feval, output_margin)
class _PackedBooster: class _PackedBooster:
@ -206,9 +274,9 @@ class _PackedBooster:
for fold in self.cvfolds: for fold in self.cvfolds:
fold.update(iteration, obj) fold.update(iteration, obj)
def eval(self, iteration, feval): def eval(self, iteration, feval, output_margin):
'''Iterate through folds for eval''' '''Iterate through folds for eval'''
result = [f.eval(iteration, feval) for f in self.cvfolds] result = [f.eval(iteration, feval, output_margin) for f in self.cvfolds]
return result return result
def set_attr(self, **kwargs): def set_attr(self, **kwargs):
@ -345,9 +413,10 @@ def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False,
def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None, def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None,
metrics=(), obj=None, feval=None, maximize=None, early_stopping_rounds=None, metrics=(), obj: Optional[Objective] = None,
feval=None, maximize=None, early_stopping_rounds=None,
fpreproc=None, as_pandas=True, verbose_eval=None, show_stdv=True, fpreproc=None, as_pandas=True, verbose_eval=None, show_stdv=True,
seed=0, callbacks=None, shuffle=True): seed=0, callbacks=None, shuffle=True, custom_metric: Optional[Metric] = None):
# pylint: disable = invalid-name # pylint: disable = invalid-name
"""Cross-validation with given parameters. """Cross-validation with given parameters.
@ -372,10 +441,15 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
indices to be used as the testing samples for the ``n`` th fold. indices to be used as the testing samples for the ``n`` th fold.
metrics : string or list of strings metrics : string or list of strings
Evaluation metrics to be watched in CV. Evaluation metrics to be watched in CV.
obj : function obj :
Custom objective function.
Custom objective function. See `Custom Objective
<https://xgboost.readthedocs.io/en/latest/tutorials/custom_metric_obj.html>`_ for
details.
feval : function feval : function
Custom evaluation function. .. deprecated:: 1.5.1
Use `custom_metric` instead.
maximize : bool maximize : bool
Whether to maximize feval. Whether to maximize feval.
early_stopping_rounds: int early_stopping_rounds: int
@ -412,6 +486,13 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
[xgb.callback.LearningRateScheduler(custom_rates)] [xgb.callback.LearningRateScheduler(custom_rates)]
shuffle : bool shuffle : bool
Shuffle data before creating folds. Shuffle data before creating folds.
custom_metric :
.. versionadded:: 1.5.1
Custom metric function. See `Custom Metric
<https://xgboost.readthedocs.io/en/latest/tutorials/custom_metric_obj.html>`_ for
details.
Returns Returns
------- -------
@ -443,6 +524,8 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
cvfolds = mknfold(dtrain, nfold, params, seed, metrics, fpreproc, cvfolds = mknfold(dtrain, nfold, params, seed, metrics, fpreproc,
stratified, folds, shuffle) stratified, folds, shuffle)
metric_fn = _configure_custom_metric(feval, custom_metric)
# setup callbacks # setup callbacks
callbacks = [] if callbacks is None else callbacks callbacks = [] if callbacks is None else callbacks
_assert_new_callback(callbacks) _assert_new_callback(callbacks)
@ -456,7 +539,12 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
callbacks.append( callbacks.append(
callback.EarlyStopping(rounds=early_stopping_rounds, maximize=maximize) callback.EarlyStopping(rounds=early_stopping_rounds, maximize=maximize)
) )
callbacks = callback.CallbackContainer(callbacks, metric=feval, is_cv=True) callbacks = callback.CallbackContainer(
callbacks,
metric=metric_fn,
is_cv=True,
output_margin=callable(obj) or metric_fn is feval,
)
booster = _PackedBooster(cvfolds) booster = _PackedBooster(cvfolds)
callbacks.before_training(booster) callbacks.before_training(booster)

View File

@@ -173,10 +173,11 @@ class TestCallbacks:
    def test_early_stopping_skl(self):
        from sklearn.datasets import load_breast_cancer
        X, y = load_breast_cancer(return_X_y=True)
-        cls = xgb.XGBClassifier()
        early_stopping_rounds = 5
-        cls.fit(X, y, eval_set=[(X, y)],
-                early_stopping_rounds=early_stopping_rounds, eval_metric='error')
+        cls = xgb.XGBClassifier(
+            early_stopping_rounds=early_stopping_rounds, eval_metric='error'
+        )
+        cls.fit(X, y, eval_set=[(X, y)])
        booster = cls.get_booster()
        dump = booster.get_dump(dump_format='json')
        assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
@@ -184,12 +185,10 @@ class TestCallbacks:
    def test_early_stopping_custom_eval_skl(self):
        from sklearn.datasets import load_breast_cancer
        X, y = load_breast_cancer(return_X_y=True)
-        cls = xgb.XGBClassifier()
+        cls = xgb.XGBClassifier(eval_metric=tm.eval_error_metric_skl)
        early_stopping_rounds = 5
        early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds)
-        cls.fit(X, y, eval_set=[(X, y)],
-                eval_metric=tm.eval_error_metric,
-                callbacks=[early_stop])
+        cls.fit(X, y, eval_set=[(X, y)], callbacks=[early_stop])
        booster = cls.get_booster()
        dump = booster.get_dump(dump_format='json')
        assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
@@ -198,41 +197,40 @@ class TestCallbacks:
        from sklearn.datasets import load_breast_cancer
        X, y = load_breast_cancer(return_X_y=True)
        n_estimators = 100
-        cls = xgb.XGBClassifier(n_estimators=n_estimators)
+        cls = xgb.XGBClassifier(
+            n_estimators=n_estimators, eval_metric=tm.eval_error_metric_skl
+        )
        early_stopping_rounds = 5
        early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
                                                save_best=True)
-        cls.fit(X, y, eval_set=[(X, y)],
-                eval_metric=tm.eval_error_metric, callbacks=[early_stop])
+        cls.fit(X, y, eval_set=[(X, y)], callbacks=[early_stop])
        booster = cls.get_booster()
        dump = booster.get_dump(dump_format='json')
        assert len(dump) == booster.best_iteration + 1

        early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
                                                save_best=True)
-        cls = xgb.XGBClassifier(booster='gblinear', n_estimators=10)
+        cls = xgb.XGBClassifier(
+            booster='gblinear', n_estimators=10, eval_metric=tm.eval_error_metric_skl
+        )
        with pytest.raises(ValueError):
-            cls.fit(X, y, eval_set=[(X, y)], eval_metric=tm.eval_error_metric,
-                    callbacks=[early_stop])
+            cls.fit(X, y, eval_set=[(X, y)], callbacks=[early_stop])

        # No error
        early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
                                                save_best=False)
-        xgb.XGBClassifier(booster='gblinear', n_estimators=10).fit(
-            X, y, eval_set=[(X, y)],
-            eval_metric=tm.eval_error_metric,
-            callbacks=[early_stop])
+        xgb.XGBClassifier(
+            booster='gblinear', n_estimators=10, eval_metric=tm.eval_error_metric_skl
+        ).fit(X, y, eval_set=[(X, y)], callbacks=[early_stop])

    def test_early_stopping_continuation(self):
        from sklearn.datasets import load_breast_cancer
        X, y = load_breast_cancer(return_X_y=True)
-        cls = xgb.XGBClassifier()
+        cls = xgb.XGBClassifier(eval_metric=tm.eval_error_metric_skl)
        early_stopping_rounds = 5
        early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
                                                save_best=True)
-        cls.fit(X, y, eval_set=[(X, y)],
-                eval_metric=tm.eval_error_metric,
-                callbacks=[early_stop])
+        cls.fit(X, y, eval_set=[(X, y)], callbacks=[early_stop])
        booster = cls.get_booster()
        assert booster.num_boosted_rounds() == booster.best_iteration + 1
@@ -243,8 +241,8 @@ class TestCallbacks:
        cls.load_model(path)
        assert cls._Booster is not None
        early_stopping_rounds = 3
-        cls.fit(X, y, eval_set=[(X, y)], eval_metric=tm.eval_error_metric,
-                early_stopping_rounds=early_stopping_rounds)
+        cls.set_params(eval_metric=tm.eval_error_metric_skl)
+        cls.fit(X, y, eval_set=[(X, y)], early_stopping_rounds=early_stopping_rounds)
        booster = cls.get_booster()
        assert booster.num_boosted_rounds() == \
            booster.best_iteration + early_stopping_rounds + 1

View File

@@ -7,7 +7,6 @@ rng = np.random.RandomState(1994)
 class TestEarlyStopping:
    @pytest.mark.skipif(**tm.no_sklearn())
    def test_early_stopping_nonparallel(self):
-
        from sklearn.datasets import load_digits

View File

@@ -1663,11 +1663,16 @@ class TestDaskCallbacks:
        valid_X, valid_y = load_breast_cancer(return_X_y=True)
        valid_X, valid_y = da.from_array(valid_X), da.from_array(valid_y)
-        cls = xgb.dask.DaskXGBClassifier(objective='binary:logistic', tree_method='hist',
-                                         n_estimators=1000)
+        cls = xgb.dask.DaskXGBClassifier(
+            objective='binary:logistic',
+            tree_method='hist',
+            n_estimators=1000,
+            eval_metric=tm.eval_error_metric_skl
+        )
        cls.client = client
-        cls.fit(X, y, early_stopping_rounds=early_stopping_rounds,
-                eval_set=[(valid_X, valid_y)], eval_metric=tm.eval_error_metric)
+        cls.fit(
+            X, y, early_stopping_rounds=early_stopping_rounds, eval_set=[(valid_X, valid_y)]
+        )
        booster = cls.get_booster()
        dump = booster.get_dump(dump_format='json')
        assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
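
The Dask estimators pick up the same constructor-based configuration. A rough
sketch, assuming a local cluster (the ``LocalCluster`` setup, the chunking, and
the metric choice are all illustrative):

    import xgboost as xgb
    from dask import array as da
    from dask.distributed import Client, LocalCluster
    from sklearn.datasets import load_breast_cancer

    with Client(LocalCluster(n_workers=2)) as client:
        X, y = load_breast_cancer(return_X_y=True)
        X = da.from_array(X, chunks=(100, -1))
        y = da.from_array(y, chunks=100)
        clf = xgb.dask.DaskXGBClassifier(
            objective="binary:logistic",
            n_estimators=100,
            eval_metric="error",  # model parameter, not a fit() argument
        )
        clf.client = client
        clf.fit(X, y, eval_set=[(X, y)], early_stopping_rounds=5)
        print(clf.best_iteration)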

View File

@@ -1271,3 +1271,76 @@ def test_prediction_config():
    reg.set_params(booster="gblinear")
    assert reg._can_use_inplace_predict() is False
+
+
+def test_evaluation_metric():
+    from sklearn.datasets import load_diabetes, load_digits
+    from sklearn.metrics import mean_absolute_error
+
+    X, y = load_diabetes(return_X_y=True)
+    n_estimators = 16
+
+    with tm.captured_output() as (out, err):
+        reg = xgb.XGBRegressor(
+            tree_method="hist",
+            eval_metric=mean_absolute_error,
+            n_estimators=n_estimators,
+        )
+        reg.fit(X, y, eval_set=[(X, y)])
+        lines = out.getvalue().strip().split('\n')
+
+    assert len(lines) == n_estimators
+    for line in lines:
+        assert line.find("mean_absolute_error") != -1
+
+    def metric(predt: np.ndarray, Xy: xgb.DMatrix):
+        y = Xy.get_label()
+        return "m", np.abs(predt - y).sum()
+
+    with pytest.warns(UserWarning):
+        reg = xgb.XGBRegressor(
+            tree_method="hist",
+            n_estimators=1,
+        )
+        reg.fit(X, y, eval_set=[(X, y)], eval_metric=metric)
+
+    def merror(y_true: np.ndarray, predt: np.ndarray):
+        n_samples = y_true.shape[0]
+        assert n_samples == predt.size
+        errors = np.zeros(y_true.shape[0])
+        errors[y_true != predt] = 1.0
+        return np.sum(errors) / n_samples
+
+    X, y = load_digits(n_class=10, return_X_y=True)
+
+    clf = xgb.XGBClassifier(
+        use_label_encoder=False,
+        tree_method="hist",
+        eval_metric=merror,
+        n_estimators=16,
+        objective="multi:softmax"
+    )
+    clf.fit(X, y, eval_set=[(X, y)])
+    custom = clf.evals_result()
+
+    clf = xgb.XGBClassifier(
+        use_label_encoder=False,
+        tree_method="hist",
+        eval_metric="merror",
+        n_estimators=16,
+        objective="multi:softmax"
+    )
+    clf.fit(X, y, eval_set=[(X, y)])
+    internal = clf.evals_result()
+
+    np.testing.assert_allclose(
+        custom["validation_0"]["merror"], internal["validation_0"]["merror"]
+    )
+
+    clf = xgb.XGBRFClassifier(
+        use_label_encoder=False,
+        tree_method="hist", n_estimators=16,
+        objective=tm.softprob_obj(10),
+        eval_metric=merror,
+    )
+    with pytest.raises(AssertionError):
+        # shape check inside the `merror` function
+        clf.fit(X, y, eval_set=[(X, y)])
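
The first block of this test relies on the new behaviour: because ``eval_metric``
now receives transformed predictions, an off-the-shelf scikit-learn cost function
with signature ``(y_true, y_pred)`` plugs in directly. A condensed sketch of that
pattern:

    import xgboost as xgb
    from sklearn.datasets import load_diabetes
    from sklearn.metrics import mean_absolute_error

    X, y = load_diabetes(return_X_y=True)

    # No link-function handling is needed; the metric sees final predictions.
    reg = xgb.XGBRegressor(eval_metric=mean_absolute_error, n_estimators=8)
    reg.fit(X, y, eval_set=[(X, y)])
    print(reg.evals_result()["validation_0"]["mean_absolute_error"][-1])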

View File

@@ -338,6 +338,7 @@ def non_increasing(L, tolerance=1e-4):
 def eval_error_metric(predt, dtrain: xgb.DMatrix):
+    """Evaluation metric for xgb.train"""
    label = dtrain.get_label()
    r = np.zeros(predt.shape)
    gt = predt > 0.5
@@ -349,6 +350,16 @@ def eval_error_metric(predt, dtrain: xgb.DMatrix):
    return 'CustomErr', np.sum(r)

+
+def eval_error_metric_skl(y_true: np.ndarray, y_score: np.ndarray) -> float:
+    """Evaluation metric that looks like metrics provided by sklearn."""
+    r = np.zeros(y_score.shape)
+    gt = y_score > 0.5
+    r[gt] = 1 - y_true[gt]
+    le = y_score <= 0.5
+    r[le] = y_true[le]
+    return np.sum(r)
+
 def softmax(x):
    e = np.exp(x)
    return e / np.sum(e)
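
The two helpers target different entry points. A brief sketch of how they are
meant to be wired up, assuming this testing module is importable as ``tm``:

    import xgboost as xgb
    from sklearn.datasets import load_breast_cancer

    X, y = load_breast_cancer(return_X_y=True)

    # Native interface: the metric sees (predt, DMatrix), returns (name, value).
    dtrain = xgb.DMatrix(X, label=y)
    xgb.train({"objective": "binary:logistic"}, dtrain, num_boost_round=2,
              custom_metric=tm.eval_error_metric, evals=[(dtrain, "train")])

    # sklearn interface: the metric sees (y_true, y_score), returns a float.
    xgb.XGBClassifier(eval_metric=tm.eval_error_metric_skl, n_estimators=2).fit(
        X, y, eval_set=[(X, y)]
    )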