From 45aef75ccaf79e10bdf081cf0b60e6c8f42116c5 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 28 Oct 2021 17:20:20 +0800 Subject: [PATCH] Move skl `eval_metric` and `early_stopping rounds` to model params. (#6751) A new parameter `custom_metric` is added to `train` and `cv` to distinguish the behaviour from the old `feval`. And `feval` is deprecated. The new `custom_metric` receives transformed prediction when the built-in objective is used. This enables XGBoost to use cost functions from other libraries like scikit-learn directly without going through the definition of the link function. `eval_metric` and `early_stopping_rounds` in sklearn interface are moved from `fit` to `__init__` and is now saved as part of the scikit-learn model. The old ones in `fit` function are now deprecated. The new `eval_metric` in `__init__` has the same new behaviour as `custom_metric`. Added more detailed documents for the behaviour of custom objective and metric. --- demo/guide-python/custom_rmsle.py | 2 +- demo/guide-python/custom_softmax.py | 13 +- doc/tutorials/custom_metric_obj.rst | 198 ++++++++++++++++++-- python-package/xgboost/callback.py | 16 +- python-package/xgboost/core.py | 26 ++- python-package/xgboost/dask.py | 45 +++-- python-package/xgboost/sklearn.py | 273 ++++++++++++++++++++-------- python-package/xgboost/training.py | 160 ++++++++++++---- tests/python/test_callback.py | 44 +++-- tests/python/test_early_stopping.py | 1 - tests/python/test_with_dask.py | 13 +- tests/python/test_with_sklearn.py | 73 ++++++++ tests/python/testing.py | 11 ++ 13 files changed, 685 insertions(+), 190 deletions(-) diff --git a/demo/guide-python/custom_rmsle.py b/demo/guide-python/custom_rmsle.py index 6292636cb..0f8d5fcb2 100644 --- a/demo/guide-python/custom_rmsle.py +++ b/demo/guide-python/custom_rmsle.py @@ -144,7 +144,7 @@ def py_rmsle(dtrain: xgb.DMatrix, dtest: xgb.DMatrix) -> Dict: dtrain=dtrain, num_boost_round=kBoostRound, obj=squared_log, - feval=rmsle, + custom_metric=rmsle, evals=[(dtrain, 'dtrain'), (dtest, 'dtest')], evals_result=results) diff --git a/demo/guide-python/custom_softmax.py b/demo/guide-python/custom_softmax.py index e54853aa0..bb53d6e5c 100644 --- a/demo/guide-python/custom_softmax.py +++ b/demo/guide-python/custom_softmax.py @@ -3,6 +3,9 @@ only applicable after (excluding) XGBoost 1.0.0, as before this version XGBoost returns transformed prediction for multi-class objective function. More details in comments. +See https://xgboost.readthedocs.io/en/latest/tutorials/custom_metric_obj.html for detailed +tutorial and notes. + ''' import numpy as np @@ -95,7 +98,12 @@ def predict(booster: xgb.Booster, X): def merror(predt: np.ndarray, dtrain: xgb.DMatrix): y = dtrain.get_label() - # Like custom objective, the predt is untransformed leaf weight + # Like custom objective, the predt is untransformed leaf weight when custom objective + # is provided. + + # With the use of `custom_metric` parameter in train function, custom metric receives + # raw input only when custom objective is also being used. Otherwise custom metric + # will receive transformed prediction. assert predt.shape == (kRows, kClasses) out = np.zeros(kRows) for r in range(predt.shape[0]): @@ -134,7 +142,7 @@ def main(args): m, num_boost_round=kRounds, obj=softprob_obj, - feval=merror, + custom_metric=merror, evals_result=custom_results, evals=[(m, 'train')]) @@ -143,6 +151,7 @@ def main(args): native_results = {} # Use the same objective function defined in XGBoost. booster_native = xgb.train({'num_class': kClasses, + "objective": "multi:softmax", 'eval_metric': 'merror'}, m, num_boost_round=kRounds, diff --git a/doc/tutorials/custom_metric_obj.rst b/doc/tutorials/custom_metric_obj.rst index eeb7e728a..5dbab173b 100644 --- a/doc/tutorials/custom_metric_obj.rst +++ b/doc/tutorials/custom_metric_obj.rst @@ -2,6 +2,16 @@ Custom Objective and Evaluation Metric ###################################### +**Contents** + +.. contents:: + :backlinks: none + :local: + +******** +Overview +******** + XGBoost is designed to be an extensible library. One way to extend it is by providing our own objective function for training and corresponding metric for performance monitoring. This document introduces implementing a customized elementwise evaluation metric and @@ -11,12 +21,8 @@ concepts should be readily applicable to other language bindings. .. note:: * The ranking task does not support customized functions. - * The customized functions defined here are only applicable to single node training. - Distributed environment requires syncing with ``xgboost.rabit``, the interface is - subject to change hence beyond the scope of this tutorial. - * We also plan to improve the interface for multi-classes objective in the future. -In the following sections, we will provide a step by step walk through of implementing +In the following two sections, we will provide a step by step walk through of implementing ``Squared Log Error(SLE)`` objective function: .. math:: @@ -30,7 +36,10 @@ and its default metric ``Root Mean Squared Log Error(RMSLE)``: Although XGBoost has native support for said functions, using it for demonstration provides us the opportunity of comparing the result from our own implementation and the one from XGBoost internal for learning purposes. After finishing this tutorial, we should -be able to provide our own functions for rapid experiments. +be able to provide our own functions for rapid experiments. And at the end, we will +provide some notes on non-identy link function along with examples of using custom metric +and objective with `scikit-learn` interface. +with scikit-learn interface. ***************************** Customized Objective Function @@ -125,12 +134,12 @@ We will be able to see XGBoost printing something like: .. code-block:: none - [0] dtrain-PyRMSLE:1.37153 dtest-PyRMSLE:1.31487 - [1] dtrain-PyRMSLE:1.26619 dtest-PyRMSLE:1.20899 - [2] dtrain-PyRMSLE:1.17508 dtest-PyRMSLE:1.11629 - [3] dtrain-PyRMSLE:1.09836 dtest-PyRMSLE:1.03871 - [4] dtrain-PyRMSLE:1.03557 dtest-PyRMSLE:0.977186 - [5] dtrain-PyRMSLE:0.985783 dtest-PyRMSLE:0.93057 + [0] dtrain-PyRMSLE:1.37153 dtest-PyRMSLE:1.31487 + [1] dtrain-PyRMSLE:1.26619 dtest-PyRMSLE:1.20899 + [2] dtrain-PyRMSLE:1.17508 dtest-PyRMSLE:1.11629 + [3] dtrain-PyRMSLE:1.09836 dtest-PyRMSLE:1.03871 + [4] dtrain-PyRMSLE:1.03557 dtest-PyRMSLE:0.977186 + [5] dtrain-PyRMSLE:0.985783 dtest-PyRMSLE:0.93057 ... Notice that the parameter ``disable_default_eval_metric`` is used to suppress the default metric @@ -138,11 +147,164 @@ in XGBoost. For fully reproducible source code and comparison plots, see `custom_rmsle.py `_. +********************* +Reverse Link Function +********************* -****************************** -Multi-class objective function -****************************** +When using builtin objective, the raw prediction is transformed according to the objective +function. When custom objective is provided XGBoost doesn't know its link function so the +user is responsible for making the transformation for both objective and custom evaluation +metric. For objective with identiy link like ``squared error`` this is trivial, but for +other link functions like log link or inverse link the difference is significant. -A similar demo for multi-class objective function is also available, see -`demo/guide-python/custom_softmax.py `_ -for details. +For the Python package, the behaviour of prediction can be controlled by the +``output_margin`` parameter in ``predict`` function. When using the ``custom_metric`` +parameter without a custom objective, the metric function will receive transformed +prediction since the objective is defined by XGBoost. However, when custom objective is +also provided along with that metric, then both the objective and custom metric will +recieve raw prediction. Following example provides a comparison between two different +behavior with a multi-class classification model. Firstly we define 2 different Python +metric functions implementing the same underlying metric for comparison, +`merror_with_transform` is used when custom objective is also used, otherwise the simpler +`merror` is preferred since XGBoost can perform the transformation itself. + +.. code-block:: python + + import xgboost as xgb + import numpy as np + + def merror_with_transform(predt: np.ndarray, dtrain: xgb.DMatrix): + """Used when custom objective is supplied.""" + y = dtrain.get_label() + n_classes = predt.size // y.shape[0] + # Like custom objective, the predt is untransformed leaf weight when custom objective + # is provided. + + # With the use of `custom_metric` parameter in train function, custom metric receives + # raw input only when custom objective is also being used. Otherwise custom metric + # will receive transformed prediction. + assert predt.shape == (d_train.num_row(), n_classes) + out = np.zeros(dtrain.num_row()) + for r in range(predt.shape[0]): + i = np.argmax(predt[r]) + out[r] = i + + assert y.shape == out.shape + + errors = np.zeros(dtrain.num_row()) + errors[y != out] = 1.0 + return 'PyMError', np.sum(errors) / dtrain.num_row() + +The above function is only needed when we want to use custom objective and XGBoost doesn't +know how to transform the prediction. The normal implementation for multi-class error +function is: + +.. code-block:: python + + def merror(predt: np.ndarray, dtrain: xgb.DMatrix): + """Used when there's no custom objective.""" + # No need to do transform, XGBoost handles it internally. + errors = np.zeros(dtrain.num_row()) + errors[y != out] = 1.0 + return 'PyMError', np.sum(errors) / dtrain.num_row() + + +Next we need the custom softprob objective: + +.. code-block:: python + + def softprob_obj(predt: np.ndarray, data: xgb.DMatrix): + """Loss function. Computing the gradient and approximated hessian (diagonal). + Reimplements the `multi:softprob` inside XGBoost. + """ + + # Full implementation is available in the Python demo script linked below + ... + + return grad, hess + +Lastly we can train the model using ``obj`` and ``custom_metric`` parameters: + +.. code-block:: python + + Xy = xgb.DMatrix(X, y) + booster = xgb.train( + {"num_class": kClasses, "disable_default_eval_metric": True}, + m, + num_boost_round=kRounds, + obj=softprob_obj, + custom_metric=merror_with_transform, + evals_result=custom_results, + evals=[(m, "train")], + ) + +Or if you don't need the custom objective and just want to supply a metric that's not +available in XGBoost: + +.. code-block:: python + + booster = xgb.train( + { + "num_class": kClasses, + "disable_default_eval_metric": True, + "objective": "multi:softmax", + }, + m, + num_boost_round=kRounds, + # Use a simpler metric implementation. + custom_metric=merror, + evals_result=custom_results, + evals=[(m, "train")], + ) + +We use ``multi:softmax`` to illustrate the differences of transformed prediction. With +``softprob`` the output prediction array has shape ``(n_samples, n_classes)`` while for +``softmax`` it's ``(n_samples, )``. A demo for multi-class objective function is also +available at `demo/guide-python/custom_softmax.py +`_ + + +********************** +Scikit-Learn Interface +********************** + + +The scikit-learn interface of XGBoost has some utilities to improve the integration with +standard scikit-learn functions. For instance, after XGBoost 1.5.1 users can use the cost +function (not scoring functions) from scikit-learn out of the box: + +.. code-block:: python + + from sklearn.datasets import load_diabetes + from sklearn.metrics import mean_absolute_error + X, y = load_diabetes(return_X_y=True) + reg = xgb.XGBRegressor( + tree_method="hist", + eval_metric=mean_absolute_error, + ) + reg.fit(X, y, eval_set=[(X, y)]) + +Also, for custom objective function, users can define the objective without having to +access ``DMatrix``: + +.. code-block:: python + + def softprob_obj(labels: np.ndarray, predt: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + rows = labels.shape[0] + grad = np.zeros((rows, classes), dtype=float) + hess = np.zeros((rows, classes), dtype=float) + eps = 1e-6 + for r in range(predt.shape[0]): + target = labels[r] + p = softmax(predt[r, :]) + for c in range(predt.shape[1]): + g = p[c] - 1.0 if c == target else p[c] + h = max((2.0 * p[c] * (1.0 - p[c])).item(), eps) + grad[r, c] = g + hess[r, c] = h + + grad = grad.reshape((rows * classes, 1)) + hess = hess.reshape((rows * classes, 1)) + return grad, hess + + clf = xgb.XGBClassifier(tree_method="hist", objective=softprob_obj) diff --git a/python-package/xgboost/callback.py b/python-package/xgboost/callback.py index 7a5504bd2..7552db79d 100644 --- a/python-package/xgboost/callback.py +++ b/python-package/xgboost/callback.py @@ -103,10 +103,13 @@ class CallbackContainer: EvalsLog = TrainingCallback.EvalsLog - def __init__(self, - callbacks: List[TrainingCallback], - metric: Callable = None, - is_cv: bool = False): + def __init__( + self, + callbacks: List[TrainingCallback], + metric: Callable = None, + output_margin: bool = True, + is_cv: bool = False + ) -> None: self.callbacks = set(callbacks) if metric is not None: msg = 'metric must be callable object for monitoring. For ' + \ @@ -115,6 +118,7 @@ class CallbackContainer: assert callable(metric), msg self.metric = metric self.history: TrainingCallback.EvalsLog = collections.OrderedDict() + self._output_margin = output_margin self.is_cv = is_cv if self.is_cv: @@ -171,7 +175,7 @@ class CallbackContainer: def after_iteration(self, model, epoch, dtrain, evals) -> bool: '''Function called after training iteration.''' if self.is_cv: - scores = model.eval(epoch, self.metric) + scores = model.eval(epoch, self.metric, self._output_margin) scores = _aggcv(scores) self.aggregated_cv = scores self._update_history(scores, epoch) @@ -179,7 +183,7 @@ class CallbackContainer: evals = [] if evals is None else evals for _, name in evals: assert name.find('-') == -1, 'Dataset name should not contain `-`' - score = model.eval_set(evals, epoch, self.metric) + score = model.eval_set(evals, epoch, self.metric, self._output_margin) score = score.split()[1:] # into datasets # split up `test-error:0.1234` score = [tuple(s.split(':')) for s in score] diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index ddc1ab969..e53fdb21d 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -1700,7 +1700,7 @@ class Booster(object): c_array(ctypes.c_float, hess), c_bst_ulong(len(grad)))) - def eval_set(self, evals, iteration=0, feval=None): + def eval_set(self, evals, iteration=0, feval=None, output_margin=True): # pylint: disable=invalid-name """Evaluate a set of data. @@ -1728,24 +1728,30 @@ class Booster(object): dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals]) evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals]) msg = ctypes.c_char_p() - _check_call(_LIB.XGBoosterEvalOneIter(self.handle, - ctypes.c_int(iteration), - dmats, evnames, - c_bst_ulong(len(evals)), - ctypes.byref(msg))) + _check_call( + _LIB.XGBoosterEvalOneIter( + self.handle, + ctypes.c_int(iteration), + dmats, + evnames, + c_bst_ulong(len(evals)), + ctypes.byref(msg), + ) + ) res = msg.value.decode() # pylint: disable=no-member if feval is not None: for dmat, evname in evals: - feval_ret = feval(self.predict(dmat, training=False, - output_margin=True), dmat) + feval_ret = feval( + self.predict(dmat, training=False, output_margin=output_margin), dmat + ) if isinstance(feval_ret, list): for name, val in feval_ret: # pylint: disable=consider-using-f-string - res += '\t%s-%s:%f' % (evname, name, val) + res += "\t%s-%s:%f" % (evname, name, val) else: name, val = feval_ret # pylint: disable=consider-using-f-string - res += '\t%s-%s:%f' % (evname, name, val) + res += "\t%s-%s:%f" % (evname, name, val) return res def eval(self, data, name='eval', iteration=0): diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 57df22a5c..e96a21a12 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -844,6 +844,7 @@ async def _train_async( verbose_eval: Union[int, bool], xgb_model: Optional[Booster], callbacks: Optional[List[TrainingCallback]], + custom_metric: Optional[Metric], ) -> Optional[TrainReturnT]: workers = _get_workers_from_data(dtrain, evals) _rabit_args = await _get_rabit_args(len(workers), client) @@ -896,6 +897,7 @@ async def _train_async( evals=local_evals, obj=obj, feval=feval, + custom_metric=custom_metric, early_stopping_rounds=early_stopping_rounds, verbose_eval=verbose_eval, xgb_model=xgb_model, @@ -942,11 +944,13 @@ async def _train_async( return list(filter(lambda ret: ret is not None, results))[0] +@_deprecate_positional_args def train( # pylint: disable=unused-argument client: "distributed.Client", params: Dict[str, Any], dtrain: DaskDMatrix, num_boost_round: int = 10, + *, evals: Optional[List[Tuple[DaskDMatrix, str]]] = None, obj: Optional[Objective] = None, feval: Optional[Metric] = None, @@ -954,6 +958,7 @@ def train( # pylint: disable=unused-argument xgb_model: Optional[Booster] = None, verbose_eval: Union[int, bool] = True, callbacks: Optional[List[TrainingCallback]] = None, + custom_metric: Optional[Metric] = None, ) -> Any: """Train XGBoost model. @@ -1647,7 +1652,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase): eval_metric: Optional[Union[str, List[str], Metric]], sample_weight_eval_set: Optional[List[_DaskCollection]], base_margin_eval_set: Optional[List[_DaskCollection]], - early_stopping_rounds: int, + early_stopping_rounds: Optional[int], verbose: bool, xgb_model: Optional[Union[Booster, XGBModel]], feature_weights: Optional[_DaskCollection], @@ -1676,8 +1681,8 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase): obj: Optional[Callable] = _objective_decorator(self.objective) else: obj = None - model, metric, params = self._configure_fit( - booster=xgb_model, eval_metric=eval_metric, params=params + model, metric, params, early_stopping_rounds = self._configure_fit( + xgb_model, eval_metric, params, early_stopping_rounds ) results = await self.client.sync( _train_async, @@ -1689,7 +1694,8 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase): num_boost_round=self.get_num_boosting_rounds(), evals=evals, obj=obj, - feval=metric, + feval=None, + custom_metric=metric, verbose_eval=verbose, early_stopping_rounds=early_stopping_rounds, callbacks=callbacks, @@ -1736,7 +1742,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase): eval_metric: Optional[Union[str, List[str], Metric]], sample_weight_eval_set: Optional[List[_DaskCollection]], base_margin_eval_set: Optional[List[_DaskCollection]], - early_stopping_rounds: int, + early_stopping_rounds: Optional[int], verbose: bool, xgb_model: Optional[Union[Booster, XGBModel]], feature_weights: Optional[_DaskCollection], @@ -1778,8 +1784,8 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase): obj: Optional[Callable] = _objective_decorator(self.objective) else: obj = None - model, metric, params = self._configure_fit( - booster=xgb_model, eval_metric=eval_metric, params=params + model, metric, params, early_stopping_rounds = self._configure_fit( + xgb_model, eval_metric, params, early_stopping_rounds ) results = await self.client.sync( _train_async, @@ -1791,7 +1797,8 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase): num_boost_round=self.get_num_boosting_rounds(), evals=evals, obj=obj, - feval=metric, + feval=None, + custom_metric=metric, verbose_eval=verbose, early_stopping_rounds=early_stopping_rounds, callbacks=callbacks, @@ -1832,9 +1839,14 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase): base_margin: Optional[_DaskCollection], iteration_range: Optional[Tuple[int, int]], ) -> _DaskCollection: + if self.objective == "multi:softmax": + raise ValueError( + "multi:softmax doesn't support `predict_proba`. " + "Switch to `multi:softproba` instead" + ) predts = await super()._predict_async( data=X, - output_margin=self.objective == "multi:softmax", + output_margin=False, validate_features=validate_features, base_margin=base_margin, iteration_range=iteration_range, @@ -1903,9 +1915,9 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase): """, ["estimators", "model"], end_note=""" - Note - ---- - For dask implementation, group is not supported, use qid instead. + .. note:: + + For dask implementation, group is not supported, use qid instead. """, ) class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn): @@ -1929,7 +1941,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn): eval_group: Optional[List[_DaskCollection]], eval_qid: Optional[List[_DaskCollection]], eval_metric: Optional[Union[str, List[str], Metric]], - early_stopping_rounds: int, + early_stopping_rounds: Optional[int], verbose: bool, xgb_model: Optional[Union[XGBModel, Booster]], feature_weights: Optional[_DaskCollection], @@ -1963,8 +1975,8 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn): raise ValueError( "Custom evaluation metric is not yet supported for XGBRanker." ) - model, metric, params = self._configure_fit( - booster=xgb_model, eval_metric=eval_metric, params=params + model, metric, params, early_stopping_rounds = self._configure_fit( + xgb_model, eval_metric, params, early_stopping_rounds ) results = await self.client.sync( _train_async, @@ -1976,7 +1988,8 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn): num_boost_round=self.get_num_boosting_rounds(), evals=evals, obj=None, - feval=metric, + feval=None, + custom_metric=metric, verbose_eval=verbose, early_stopping_rounds=early_stopping_rounds, callbacks=callbacks, diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index feea9c5b3..dd877cfc0 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -89,6 +89,19 @@ def _objective_decorator( return inner +def _metric_decorator(func: Callable) -> Metric: + """Decorate a metric function from sklearn. + + Converts an metric function that uses the typical sklearn metric signature so that it + is compatible with :py:func:`train` + + """ + def inner(y_score: np.ndarray, dmatrix: DMatrix) -> Tuple[str, float]: + y_true = dmatrix.get_label() + return func.__name__, func(y_true, y_score) + return inner + + __estimator_doc = ''' n_estimators : int Number of gradient boosted trees. Equivalent to number of boosting @@ -183,6 +196,66 @@ __model_doc = f''' Experimental support for categorical data. Do not set to true unless you are interested in development. Only valid when `gpu_hist` and dataframe are used. + eval_metric : Optional[Union[str, List[str], Callable]] + + .. versionadded:: 1.5.1 + + Metric used for monitoring the training result and early stopping. It can be a + string or list of strings as names of predefined metric in XGBoost (See + doc/parameter.rst), one of the metrics in :py:mod:`sklearn.metrics`, or any other + user defined metric that looks like `sklearn.metrics`. + + If custom objective is also provided, then custom metric should implement the + corresponding reverse link function. + + Unlike the `scoring` parameter commonly used in scikit-learn, when a callable + object is provided, it's assumed to be a cost function and by default XGBoost will + minimize the result during early stopping. + + For advanced usage on Early stopping like directly choosing to maximize instead of + minimize, see :py:obj:`xgboost.callback.EarlyStopping`. + + See `Custom Objective and Evaluation Metric + `_ for + more. + + .. note:: + + This parameter replaces `eval_metric` in :py:meth:`fit` method. The old one + receives un-transformed prediction regardless of whether custom objective is + being used. + + .. code-block:: python + + from sklearn.datasets import load_diabetes + from sklearn.metrics import mean_absolute_error + X, y = load_diabetes(return_X_y=True) + reg = xgb.XGBRegressor( + tree_method="hist", + eval_metric=mean_absolute_error, + ) + reg.fit(X, y, eval_set=[(X, y)]) + + early_stopping_rounds : Optional[int] + + .. versionadded:: 1.5.1 + + Activates early stopping. Validation metric needs to improve at least once in + every **early_stopping_rounds** round(s) to continue training. Requires at least + one item in **eval_set** in :py:meth:`xgboost.sklearn.XGBModel.fit`. + + The method returns the model from the last iteration (not the best one). If + there's more than one item in **eval_set**, the last entry will be used for early + stopping. If there's more than one metric in **eval_metric**, the last metric + will be used for early stopping. + + If early stopping occurs, the model will have three additional fields: + ``clf.best_score``, ``clf.best_iteration`` and ``clf.best_ntree_limit``. + + .. note:: + + This parameter replaces `early_stopping_rounds` in :py:meth:`fit` method. + kwargs : dict, optional Keyword arguments for XGBoost Booster object. Full documentation of parameters can be found here: @@ -397,6 +470,8 @@ class XGBModel(XGBModelBase): validate_parameters: Optional[bool] = None, predictor: Optional[str] = None, enable_categorical: bool = False, + eval_metric: Optional[Union[str, List[str], Callable]] = None, + early_stopping_rounds: Optional[int] = None, **kwargs: Any ) -> None: if not SKLEARN_INSTALLED: @@ -433,6 +508,8 @@ class XGBModel(XGBModelBase): self.validate_parameters = validate_parameters self.predictor = predictor self.enable_categorical = enable_categorical + self.eval_metric = eval_metric + self.early_stopping_rounds = early_stopping_rounds if kwargs: self.kwargs = kwargs @@ -543,8 +620,13 @@ class XGBModel(XGBModelBase): params = self.get_params() # Parameters that should not go into native learner. wrapper_specific = { - 'importance_type', 'kwargs', 'missing', 'n_estimators', 'use_label_encoder', - "enable_categorical" + "importance_type", + "kwargs", + "missing", + "n_estimators", + "use_label_encoder", + "enable_categorical", + "early_stopping_rounds", } filtered = {} for k, v in params.items(): @@ -629,32 +711,80 @@ class XGBModel(XGBModelBase): load_model.__doc__ = f"""{Booster.load_model.__doc__}""" + # pylint: disable=too-many-branches def _configure_fit( self, booster: Optional[Union[Booster, "XGBModel", str]], eval_metric: Optional[Union[Callable, str, List[str]]], params: Dict[str, Any], - ) -> Tuple[Optional[Union[Booster, str]], Optional[Metric], Dict[str, Any]]: - # pylint: disable=protected-access, no-self-use + early_stopping_rounds: Optional[int], + ) -> Tuple[ + Optional[Union[Booster, str, "XGBModel"]], + Optional[Metric], + Dict[str, Any], + Optional[int], + ]: + """Configure parameters for :py:meth:`fit`.""" if isinstance(booster, XGBModel): - # Handle the case when xgb_model is a sklearn model object - model: Optional[Union[Booster, str]] = booster._Booster + model: Optional[Union[Booster, str]] = booster.get_booster() else: model = booster - feval = eval_metric if callable(eval_metric) else None + def _deprecated(parameter: str) -> None: + warnings.warn( + f"`{parameter}` in `fit` method is deprecated for better compatibility " + f"with scikit-learn, use `{parameter}` in constructor or`set_params` " + "instead.", + UserWarning, + ) + + def _duplicated(parameter: str) -> None: + raise ValueError( + f"2 different `{parameter}` are provided. Use the one in constructor " + "or `set_params` instead." + ) + + # Configure evaluation metric. if eval_metric is not None: - if callable(eval_metric): - eval_metric = None + _deprecated("eval_metric") + if self.eval_metric is not None and eval_metric is not None: + _duplicated("eval_metric") + # - track where does the evaluation metric come from + if self.eval_metric is not None: + from_fit = False + eval_metric = self.eval_metric + else: + from_fit = True + # - configure callable evaluation metric + metric: Optional[Metric] = None + if eval_metric is not None: + if callable(eval_metric) and from_fit: + # No need to wrap the evaluation function for old parameter. + metric = eval_metric + elif callable(eval_metric): + # Parameter from constructor or set_params + metric = _metric_decorator(eval_metric) else: params.update({"eval_metric": eval_metric}) + + # Configure early_stopping_rounds + if early_stopping_rounds is not None: + _deprecated("early_stopping_rounds") + if early_stopping_rounds is not None and self.early_stopping_rounds is not None: + _duplicated("early_stopping_rounds") + early_stopping_rounds = ( + self.early_stopping_rounds + if self.early_stopping_rounds is not None + else early_stopping_rounds + ) + if self.enable_categorical and params.get("tree_method", None) != "gpu_hist": raise ValueError( "Experimental support for categorical data is not implemented for" " current tree method yet." ) - return model, feval, params + return model, metric, params, early_stopping_rounds def _set_evaluation_result(self, evals_result: TrainingCallback.EvalsLog) -> None: if evals_result: @@ -702,31 +832,15 @@ class XGBModel(XGBModelBase): A list of (X, y) tuple pairs to use as validation sets, for which metrics will be computed. Validation metrics will help us track the performance of the model. - eval_metric : - If a str, should be a built-in evaluation metric to use. See doc/parameter.rst. - If a list of str, should be the list of multiple built-in evaluation metrics - to use. + eval_metric : str, list of str, or callable, optional + .. deprecated:: 1.5.1 + Use `eval_metric` in :py:meth:`__init__` or :py:meth:`set_params` instead. - If callable, a custom evaluation metric. The call signature is - ``func(y_predicted, y_true)`` where ``y_true`` will be a DMatrix object such - that you may need to call the ``get_label`` method. It must return a str, - value pair where the str is a name for the evaluation and value is the value - of the evaluation function. The callable custom objective is always minimized. - early_stopping_rounds : - Activates early stopping. Validation metric needs to improve at least once in - every **early_stopping_rounds** round(s) to continue training. - Requires at least one item in **eval_set**. - - The method returns the model from the last iteration (not the best one). - If there's more than one item in **eval_set**, the last entry will be used - for early stopping. - - If there's more than one metric in **eval_metric**, the last metric will be - used for early stopping. - - If early stopping occurs, the model will have three additional fields: - ``clf.best_score``, ``clf.best_iteration``. + early_stopping_rounds : int + .. deprecated:: 1.5.1 + Use `early_stopping_rounds` in :py:meth:`__init__` or + :py:meth:`set_params` instead. verbose : If `verbose` and an evaluation set is used, writes the evaluation metric measured on the validation set to stderr. @@ -783,7 +897,9 @@ class XGBModel(XGBModelBase): else: obj = None - model, feval, params = self._configure_fit(xgb_model, eval_metric, params) + model, metric, params, early_stopping_rounds = self._configure_fit( + xgb_model, eval_metric, params, early_stopping_rounds + ) self._Booster = train( params, train_dmatrix, @@ -792,7 +908,7 @@ class XGBModel(XGBModelBase): early_stopping_rounds=early_stopping_rounds, evals_result=evals_result, obj=obj, - feval=feval, + custom_metric=metric, verbose_eval=verbose, xgb_model=model, callbacks=callbacks, @@ -1185,12 +1301,14 @@ class XGBClassifier(XGBModel, XGBClassifierBase): obj = None if self.n_classes_ > 2: - # Switch to using a multiclass objective in the underlying - # XGB instance - params["objective"] = "multi:softprob" + # Switch to using a multiclass objective in the underlying XGB instance + if params.get("objective", None) != "multi:softmax": + params["objective"] = "multi:softprob" params["num_class"] = self.n_classes_ - model, feval, params = self._configure_fit(xgb_model, eval_metric, params) + model, metric, params, early_stopping_rounds = self._configure_fit( + xgb_model, eval_metric, params, early_stopping_rounds + ) train_dmatrix, evals = _wrap_evaluation_matrices( missing=self.missing, X=X, @@ -1217,7 +1335,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase): early_stopping_rounds=early_stopping_rounds, evals_result=evals_result, obj=obj, - feval=feval, + custom_metric=metric, verbose_eval=verbose, xgb_model=model, callbacks=callbacks, @@ -1304,12 +1422,19 @@ class XGBClassifier(XGBModel, XGBClassifierBase): """ # custom obj: Do nothing as we don't know what to do. # softprob: Do nothing, output is proba. - # softmax: Use output margin to remove the argmax in PredTransform. + # softmax: Unsupported by predict_proba() # binary:logistic: Expand the prob vector into 2-class matrix after predict. # binary:logitraw: Unsupported by predict_proba() + if self.objective == "multi:softmax": + # We need to run a Python implementation of softmax for it. Just ask user to + # use softprob since XGBoost's implementation has mitigation for floating + # point overflow. No need to reinvent the wheel. + raise ValueError( + "multi:softmax doesn't support `predict_proba`. " + "Switch to `multi:softproba` instead" + ) class_probs = super().predict( X=X, - output_margin=self.objective == "multi:softmax", ntree_limit=ntree_limit, validate_features=validate_features, base_margin=base_margin, @@ -1325,8 +1450,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase): If **eval_set** is passed to the `fit` function, you can call ``evals_result()`` to get evaluation results for all passed **eval_sets**. - When **eval_metric** is also passed to the `fit` function, the - **evals_result** will contain the **eval_metrics** passed to the `fit` function. + + When **eval_metric** is also passed as a parameter, the **evals_result** will + contain the **eval_metric** passed to the `fit` function. Returns ------- @@ -1337,13 +1463,14 @@ class XGBClassifier(XGBModel, XGBClassifierBase): .. code-block:: python - param_dist = {'objective':'binary:logistic', 'n_estimators':2} + param_dist = { + 'objective':'binary:logistic', 'n_estimators':2, eval_metric="logloss" + } clf = xgb.XGBClassifier(**param_dist) clf.fit(X_train, y_train, eval_set=[(X_train, y_train), (X_test, y_test)], - eval_metric='logloss', verbose=True) evals_result = clf.evals_result() @@ -1354,6 +1481,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase): {'validation_0': {'logloss': ['0.604835', '0.531479']}, 'validation_1': {'logloss': ['0.41965', '0.17686']}} + """ if self.evals_result_: evals_result = self.evals_result_ @@ -1386,6 +1514,7 @@ class XGBRFClassifier(XGBClassifier): colsample_bynode=colsample_bynode, reg_lambda=reg_lambda, **kwargs) + _check_rf_callback(self.early_stopping_rounds, None) def get_xgb_params(self) -> Dict[str, Any]: params = super().get_xgb_params() @@ -1457,6 +1586,7 @@ class XGBRFRegressor(XGBRegressor): reg_lambda=reg_lambda, **kwargs ) + _check_rf_callback(self.early_stopping_rounds, None) def get_xgb_params(self) -> Dict[str, Any]: params = super().get_xgb_params() @@ -1495,15 +1625,15 @@ class XGBRFRegressor(XGBRegressor): 'Implementation of the Scikit-Learn API for XGBoost Ranking.', ['estimators', 'model'], end_note=''' - Note - ---- - A custom objective function is currently not supported by XGBRanker. - Likewise, a custom metric function is not supported either. + .. note:: - Note - ---- - Query group information is required for ranking tasks by either using the `group` - parameter or `qid` parameter in `fit` method. + A custom objective function is currently not supported by XGBRanker. + Likewise, a custom metric function is not supported either. + + .. note:: + + Query group information is required for ranking tasks by either using the + `group` parameter or `qid` parameter in `fit` method. Before fitting the model, your data need to be sorted by query group. When fitting the model, you need to provide an additional array that contains the size of each @@ -1605,22 +1735,16 @@ class XGBRanker(XGBModel, XGBRankerMixIn): eval_qid : A list in which ``eval_qid[i]`` is the array containing query ID of ``i``-th pair in **eval_set**. - eval_metric : - If a str, should be a built-in evaluation metric to use. See - doc/parameter.rst. - If a list of str, should be the list of multiple built-in evaluation metrics - to use. The custom evaluation metric is not yet supported for the ranker. - early_stopping_rounds : - Activates early stopping. Validation metric needs to improve at least once in - every **early_stopping_rounds** round(s) to continue training. Requires at - least one item in **eval_set**. - The method returns the model from the last iteration (not the best one). If - there's more than one item in **eval_set**, the last entry will be used for - early stopping. - If there's more than one metric in **eval_metric**, the last metric will be - used for early stopping. - If early stopping occurs, the model will have three additional fields: - ``clf.best_score``, ``clf.best_iteration`` and ``clf.best_ntree_limit``. + + eval_metric : str, list of str, optional + .. deprecated:: 1.5.1 + use `eval_metric` in :py:meth:`__init__` or :py:meth:`set_params` instead. + + early_stopping_rounds : int + .. deprecated:: 1.5.1 + use `early_stopping_rounds` in :py:meth:`__init__` or + :py:meth:`set_params` instead. + verbose : If `verbose` and an evaluation set is used, writes the evaluation metric measured on the validation set to stderr. @@ -1685,8 +1809,10 @@ class XGBRanker(XGBModel, XGBRankerMixIn): evals_result: TrainingCallback.EvalsLog = {} params = self.get_xgb_params() - model, feval, params = self._configure_fit(xgb_model, eval_metric, params) - if callable(feval): + model, metric, params, early_stopping_rounds = self._configure_fit( + xgb_model, eval_metric, params, early_stopping_rounds + ) + if callable(metric): raise ValueError( 'Custom evaluation metric is not yet supported for XGBRanker.' ) @@ -1696,7 +1822,8 @@ class XGBRanker(XGBModel, XGBRankerMixIn): self.n_estimators, early_stopping_rounds=early_stopping_rounds, evals=evals, - evals_result=evals_result, feval=feval, + evals_result=evals_result, + custom_metric=metric, verbose_eval=verbose, xgb_model=model, callbacks=callbacks ) diff --git a/python-package/xgboost/training.py b/python-package/xgboost/training.py index 611a7fbff..2b0035a9a 100644 --- a/python-package/xgboost/training.py +++ b/python-package/xgboost/training.py @@ -4,9 +4,12 @@ """Training Library containing training routines.""" import copy from typing import Optional, List +import warnings import numpy as np from .core import Booster, XGBoostError, _get_booster_layer_trees +from .core import _deprecate_positional_args +from .core import Objective, Metric from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold) from . import callback @@ -22,21 +25,48 @@ def _assert_new_callback(callbacks: Optional[List[callback.TrainingCallback]]) - ) -def _train_internal(params, dtrain, - num_boost_round=10, evals=(), - obj=None, feval=None, - xgb_model=None, callbacks=None, - evals_result=None, maximize=None, - verbose_eval=None, early_stopping_rounds=None): +def _configure_custom_metric( + feval: Optional[Metric], custom_metric: Optional[Metric] +) -> Optional[Metric]: + if feval is not None: + link = "https://xgboost.readthedocs.io/en/latest/tutorials/custom_metric_obj.html" + warnings.warn( + "`feval` is deprecated, use `custom_metric` instead. They have " + "different behavior when custom objective is also used." + f"See {link} for details on the `custom_metric`." + ) + if feval is not None and custom_metric is not None: + raise ValueError( + "Bost `feval` and `custom_metric` are supplied. Use `custom_metric` instead." + ) + eval_metric = custom_metric if custom_metric is not None else feval + return eval_metric + + +def _train_internal( + params, + dtrain, + num_boost_round=10, + evals=(), + obj=None, + feval=None, + custom_metric=None, + xgb_model=None, + callbacks=None, + evals_result=None, + maximize=None, + verbose_eval=None, + early_stopping_rounds=None, +): """internal training function""" callbacks = [] if callbacks is None else copy.copy(callbacks) + metric_fn = _configure_custom_metric(feval, custom_metric) evals = list(evals) bst = Booster(params, [dtrain] + [d[0] for d in evals]) if xgb_model is not None: - bst = Booster(params, [dtrain] + [d[0] for d in evals], - model_file=xgb_model) + bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model) start_iteration = 0 @@ -48,7 +78,14 @@ def _train_internal(params, dtrain, callbacks.append( callback.EarlyStopping(rounds=early_stopping_rounds, maximize=maximize) ) - callbacks = callback.CallbackContainer(callbacks, metric=feval) + callbacks = callback.CallbackContainer( + callbacks, + metric=metric_fn, + # For old `feval` parameter, the behavior is unchanged. For the new + # `custom_metric`, it will receive proper prediction result when custom objective + # is not used. + output_margin=callable(obj) or metric_fn is feval, + ) bst = callbacks.before_training(bst) @@ -89,9 +126,23 @@ def _train_internal(params, dtrain, return bst.copy() -def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, - maximize=None, early_stopping_rounds=None, evals_result=None, - verbose_eval=True, xgb_model=None, callbacks=None): +@_deprecate_positional_args +def train( + params, + dtrain, + num_boost_round=10, + *, + evals=(), + obj: Optional[Objective] = None, + feval=None, + maximize=None, + early_stopping_rounds=None, + evals_result=None, + verbose_eval=True, + xgb_model=None, + callbacks=None, + custom_metric: Optional[Metric] = None, +): # pylint: disable=too-many-statements,too-many-branches, attribute-defined-outside-init """Train a booster with given parameters. @@ -106,10 +157,13 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, evals: list of pairs (DMatrix, string) List of validation sets for which metrics will evaluated during training. Validation metrics will help us track the performance of the model. - obj : function - Customized objective function. - feval : function - Customized evaluation function. + obj + Custom objective function. See `Custom Objective + `_ for + details. + feval : + .. deprecated:: 1.5.1 + Use `custom_metric` instead. maximize : bool Whether to maximize feval. early_stopping_rounds: int @@ -158,23 +212,37 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, [xgb.callback.LearningRateScheduler(custom_rates)] + custom_metric: + + .. versionadded 1.5.1 + + Custom metric function. See `Custom Metric + `_ for + details. + Returns ------- Booster : a trained booster model """ - bst = _train_internal(params, dtrain, - num_boost_round=num_boost_round, - evals=evals, - obj=obj, feval=feval, - xgb_model=xgb_model, callbacks=callbacks, - verbose_eval=verbose_eval, - evals_result=evals_result, - maximize=maximize, - early_stopping_rounds=early_stopping_rounds) + bst = _train_internal( + params, + dtrain, + num_boost_round=num_boost_round, + evals=evals, + obj=obj, + feval=feval, + xgb_model=xgb_model, + callbacks=callbacks, + verbose_eval=verbose_eval, + evals_result=evals_result, + maximize=maximize, + early_stopping_rounds=early_stopping_rounds, + custom_metric=custom_metric, + ) return bst -class CVPack(object): +class CVPack: """"Auxiliary datastruct to hold one fold of CV.""" def __init__(self, dtrain, dtest, param): """"Initialize the CVPack""" @@ -192,9 +260,9 @@ class CVPack(object): """"Update the boosters for one iteration""" self.bst.update(self.dtrain, iteration, fobj) - def eval(self, iteration, feval): + def eval(self, iteration, feval, output_margin): """"Evaluate the CVPack for one iteration.""" - return self.bst.eval_set(self.watchlist, iteration, feval) + return self.bst.eval_set(self.watchlist, iteration, feval, output_margin) class _PackedBooster: @@ -206,9 +274,9 @@ class _PackedBooster: for fold in self.cvfolds: fold.update(iteration, obj) - def eval(self, iteration, feval): + def eval(self, iteration, feval, output_margin): '''Iterate through folds for eval''' - result = [f.eval(iteration, feval) for f in self.cvfolds] + result = [f.eval(iteration, feval, output_margin) for f in self.cvfolds] return result def set_attr(self, **kwargs): @@ -345,9 +413,10 @@ def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False, def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None, - metrics=(), obj=None, feval=None, maximize=None, early_stopping_rounds=None, + metrics=(), obj: Optional[Objective] = None, + feval=None, maximize=None, early_stopping_rounds=None, fpreproc=None, as_pandas=True, verbose_eval=None, show_stdv=True, - seed=0, callbacks=None, shuffle=True): + seed=0, callbacks=None, shuffle=True, custom_metric: Optional[Metric] = None): # pylint: disable = invalid-name """Cross-validation with given parameters. @@ -372,10 +441,15 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None indices to be used as the testing samples for the ``n`` th fold. metrics : string or list of strings Evaluation metrics to be watched in CV. - obj : function - Custom objective function. + obj : + + Custom objective function. See `Custom Objective + `_ for + details. + feval : function - Custom evaluation function. + .. deprecated:: 1.5.1 + Use `custom_metric` instead. maximize : bool Whether to maximize feval. early_stopping_rounds: int @@ -412,6 +486,13 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None [xgb.callback.LearningRateScheduler(custom_rates)] shuffle : bool Shuffle data before creating folds. + custom_metric : + + .. versionadded 1.5.1 + + Custom metric function. See `Custom Metric + `_ for + details. Returns ------- @@ -443,6 +524,8 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None cvfolds = mknfold(dtrain, nfold, params, seed, metrics, fpreproc, stratified, folds, shuffle) + metric_fn = _configure_custom_metric(feval, custom_metric) + # setup callbacks callbacks = [] if callbacks is None else callbacks _assert_new_callback(callbacks) @@ -456,7 +539,12 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None callbacks.append( callback.EarlyStopping(rounds=early_stopping_rounds, maximize=maximize) ) - callbacks = callback.CallbackContainer(callbacks, metric=feval, is_cv=True) + callbacks = callback.CallbackContainer( + callbacks, + metric=metric_fn, + is_cv=True, + output_margin=callable(obj) or metric_fn is feval, + ) booster = _PackedBooster(cvfolds) callbacks.before_training(booster) diff --git a/tests/python/test_callback.py b/tests/python/test_callback.py index 3a2d5eecf..93a304e1c 100644 --- a/tests/python/test_callback.py +++ b/tests/python/test_callback.py @@ -173,10 +173,11 @@ class TestCallbacks: def test_early_stopping_skl(self): from sklearn.datasets import load_breast_cancer X, y = load_breast_cancer(return_X_y=True) - cls = xgb.XGBClassifier() early_stopping_rounds = 5 - cls.fit(X, y, eval_set=[(X, y)], - early_stopping_rounds=early_stopping_rounds, eval_metric='error') + cls = xgb.XGBClassifier( + early_stopping_rounds=early_stopping_rounds, eval_metric='error' + ) + cls.fit(X, y, eval_set=[(X, y)]) booster = cls.get_booster() dump = booster.get_dump(dump_format='json') assert len(dump) - booster.best_iteration == early_stopping_rounds + 1 @@ -184,12 +185,10 @@ class TestCallbacks: def test_early_stopping_custom_eval_skl(self): from sklearn.datasets import load_breast_cancer X, y = load_breast_cancer(return_X_y=True) - cls = xgb.XGBClassifier() + cls = xgb.XGBClassifier(eval_metric=tm.eval_error_metric_skl) early_stopping_rounds = 5 early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds) - cls.fit(X, y, eval_set=[(X, y)], - eval_metric=tm.eval_error_metric, - callbacks=[early_stop]) + cls.fit(X, y, eval_set=[(X, y)], callbacks=[early_stop]) booster = cls.get_booster() dump = booster.get_dump(dump_format='json') assert len(dump) - booster.best_iteration == early_stopping_rounds + 1 @@ -198,41 +197,40 @@ class TestCallbacks: from sklearn.datasets import load_breast_cancer X, y = load_breast_cancer(return_X_y=True) n_estimators = 100 - cls = xgb.XGBClassifier(n_estimators=n_estimators) + cls = xgb.XGBClassifier( + n_estimators=n_estimators, eval_metric=tm.eval_error_metric_skl + ) early_stopping_rounds = 5 early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds, save_best=True) - cls.fit(X, y, eval_set=[(X, y)], - eval_metric=tm.eval_error_metric, callbacks=[early_stop]) + cls.fit(X, y, eval_set=[(X, y)], callbacks=[early_stop]) booster = cls.get_booster() dump = booster.get_dump(dump_format='json') assert len(dump) == booster.best_iteration + 1 early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds, save_best=True) - cls = xgb.XGBClassifier(booster='gblinear', n_estimators=10) + cls = xgb.XGBClassifier( + booster='gblinear', n_estimators=10, eval_metric=tm.eval_error_metric_skl + ) with pytest.raises(ValueError): - cls.fit(X, y, eval_set=[(X, y)], eval_metric=tm.eval_error_metric, - callbacks=[early_stop]) + cls.fit(X, y, eval_set=[(X, y)], callbacks=[early_stop]) # No error early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds, save_best=False) - xgb.XGBClassifier(booster='gblinear', n_estimators=10).fit( - X, y, eval_set=[(X, y)], - eval_metric=tm.eval_error_metric, - callbacks=[early_stop]) + xgb.XGBClassifier( + booster='gblinear', n_estimators=10, eval_metric=tm.eval_error_metric_skl + ).fit(X, y, eval_set=[(X, y)], callbacks=[early_stop]) def test_early_stopping_continuation(self): from sklearn.datasets import load_breast_cancer X, y = load_breast_cancer(return_X_y=True) - cls = xgb.XGBClassifier() + cls = xgb.XGBClassifier(eval_metric=tm.eval_error_metric_skl) early_stopping_rounds = 5 early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds, save_best=True) - cls.fit(X, y, eval_set=[(X, y)], - eval_metric=tm.eval_error_metric, - callbacks=[early_stop]) + cls.fit(X, y, eval_set=[(X, y)], callbacks=[early_stop]) booster = cls.get_booster() assert booster.num_boosted_rounds() == booster.best_iteration + 1 @@ -243,8 +241,8 @@ class TestCallbacks: cls.load_model(path) assert cls._Booster is not None early_stopping_rounds = 3 - cls.fit(X, y, eval_set=[(X, y)], eval_metric=tm.eval_error_metric, - early_stopping_rounds=early_stopping_rounds) + cls.set_params(eval_metric=tm.eval_error_metric_skl) + cls.fit(X, y, eval_set=[(X, y)], early_stopping_rounds=early_stopping_rounds) booster = cls.get_booster() assert booster.num_boosted_rounds() == \ booster.best_iteration + early_stopping_rounds + 1 diff --git a/tests/python/test_early_stopping.py b/tests/python/test_early_stopping.py index aba4f8c08..29f8fb4b0 100644 --- a/tests/python/test_early_stopping.py +++ b/tests/python/test_early_stopping.py @@ -7,7 +7,6 @@ rng = np.random.RandomState(1994) class TestEarlyStopping: - @pytest.mark.skipif(**tm.no_sklearn()) def test_early_stopping_nonparallel(self): from sklearn.datasets import load_digits diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index 033c86451..f506d0d80 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -1663,11 +1663,16 @@ class TestDaskCallbacks: valid_X, valid_y = load_breast_cancer(return_X_y=True) valid_X, valid_y = da.from_array(valid_X), da.from_array(valid_y) - cls = xgb.dask.DaskXGBClassifier(objective='binary:logistic', tree_method='hist', - n_estimators=1000) + cls = xgb.dask.DaskXGBClassifier( + objective='binary:logistic', + tree_method='hist', + n_estimators=1000, + eval_metric=tm.eval_error_metric_skl + ) cls.client = client - cls.fit(X, y, early_stopping_rounds=early_stopping_rounds, - eval_set=[(valid_X, valid_y)], eval_metric=tm.eval_error_metric) + cls.fit( + X, y, early_stopping_rounds=early_stopping_rounds, eval_set=[(valid_X, valid_y)] + ) booster = cls.get_booster() dump = booster.get_dump(dump_format='json') assert len(dump) - booster.best_iteration == early_stopping_rounds + 1 diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index 7ca79fecc..2a400871d 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -1271,3 +1271,76 @@ def test_prediction_config(): reg.set_params(booster="gblinear") assert reg._can_use_inplace_predict() is False + + +def test_evaluation_metric(): + from sklearn.datasets import load_diabetes, load_digits + from sklearn.metrics import mean_absolute_error + X, y = load_diabetes(return_X_y=True) + n_estimators = 16 + + with tm.captured_output() as (out, err): + reg = xgb.XGBRegressor( + tree_method="hist", + eval_metric=mean_absolute_error, + n_estimators=n_estimators, + ) + reg.fit(X, y, eval_set=[(X, y)]) + lines = out.getvalue().strip().split('\n') + + assert len(lines) == n_estimators + for line in lines: + assert line.find("mean_absolute_error") != -1 + + def metric(predt: np.ndarray, Xy: xgb.DMatrix): + y = Xy.get_label() + return "m", np.abs(predt - y).sum() + + with pytest.warns(UserWarning): + reg = xgb.XGBRegressor( + tree_method="hist", + n_estimators=1, + ) + reg.fit(X, y, eval_set=[(X, y)], eval_metric=metric) + + def merror(y_true: np.ndarray, predt: np.ndarray): + n_samples = y_true.shape[0] + assert n_samples == predt.size + errors = np.zeros(y_true.shape[0]) + errors[y != predt] = 1.0 + return np.sum(errors) / n_samples + + X, y = load_digits(n_class=10, return_X_y=True) + + clf = xgb.XGBClassifier( + use_label_encoder=False, + tree_method="hist", + eval_metric=merror, + n_estimators=16, + objective="multi:softmax" + ) + clf.fit(X, y, eval_set=[(X, y)]) + custom = clf.evals_result() + + clf = xgb.XGBClassifier( + use_label_encoder=False, + tree_method="hist", + eval_metric="merror", + n_estimators=16, + objective="multi:softmax" + ) + clf.fit(X, y, eval_set=[(X, y)]) + internal = clf.evals_result() + np.testing.assert_allclose( + custom["validation_0"]["merror"], internal["validation_0"]["merror"] + ) + + clf = xgb.XGBRFClassifier( + use_label_encoder=False, + tree_method="hist", n_estimators=16, + objective=tm.softprob_obj(10), + eval_metric=merror, + ) + with pytest.raises(AssertionError): + # shape check inside the `merror` function + clf.fit(X, y, eval_set=[(X, y)]) diff --git a/tests/python/testing.py b/tests/python/testing.py index fe6d9b32c..328cc63a2 100644 --- a/tests/python/testing.py +++ b/tests/python/testing.py @@ -338,6 +338,7 @@ def non_increasing(L, tolerance=1e-4): def eval_error_metric(predt, dtrain: xgb.DMatrix): + """Evaluation metric for xgb.train""" label = dtrain.get_label() r = np.zeros(predt.shape) gt = predt > 0.5 @@ -349,6 +350,16 @@ def eval_error_metric(predt, dtrain: xgb.DMatrix): return 'CustomErr', np.sum(r) +def eval_error_metric_skl(y_true: np.ndarray, y_score: np.ndarray) -> float: + """Evaluation metric that looks like metrics provided by sklearn.""" + r = np.zeros(y_score.shape) + gt = y_score > 0.5 + r[gt] = 1 - y_true[gt] + le = y_score <= 0.5 + r[le] = y_true[le] + return np.sum(r) + + def softmax(x): e = np.exp(x) return e / np.sum(e)