diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py
index 9560ff2f1..4483ffa0b 100644
--- a/python-package/xgboost/dask.py
+++ b/python-package/xgboost/dask.py
@@ -1680,8 +1680,8 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
             obj: Optional[Callable] = _objective_decorator(self.objective)
         else:
             obj = None
-        model, metric, params, early_stopping_rounds = self._configure_fit(
-            xgb_model, eval_metric, params, early_stopping_rounds
+        model, metric, params, early_stopping_rounds, callbacks = self._configure_fit(
+            xgb_model, eval_metric, params, early_stopping_rounds, callbacks
         )
         results = await self.client.sync(
             _train_async,
@@ -1783,8 +1783,8 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
             obj: Optional[Callable] = _objective_decorator(self.objective)
         else:
             obj = None
-        model, metric, params, early_stopping_rounds = self._configure_fit(
-            xgb_model, eval_metric, params, early_stopping_rounds
+        model, metric, params, early_stopping_rounds, callbacks = self._configure_fit(
+            xgb_model, eval_metric, params, early_stopping_rounds, callbacks
         )
         results = await self.client.sync(
             _train_async,
@@ -1974,8 +1974,8 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
             raise ValueError(
                 "Custom evaluation metric is not yet supported for XGBRanker."
             )
-        model, metric, params, early_stopping_rounds = self._configure_fit(
-            xgb_model, eval_metric, params, early_stopping_rounds
+        model, metric, params, early_stopping_rounds, callbacks = self._configure_fit(
+            xgb_model, eval_metric, params, early_stopping_rounds, callbacks
         )
         results = await self.client.sync(
             _train_async,
diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py
index 7313351fd..d66bf077d 100644
--- a/python-package/xgboost/sklearn.py
+++ b/python-package/xgboost/sklearn.py
@@ -257,6 +257,16 @@ __model_doc = f'''
         This parameter replaces `early_stopping_rounds` in :py:meth:`fit` method.
 
+    callbacks : Optional[List[TrainingCallback]]
+        List of callback functions that are applied at end of each iteration.
+        It is possible to use predefined callbacks by using :ref:`callback_api`.
+        Example:
+
+        .. code-block:: python
+
+            callbacks = [xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
+                                                    save_best=True)]
+
     kwargs : dict, optional
         Keyword arguments for XGBoost Booster object.
         Full documentation of parameters can be found here:
@@ -473,6 +483,7 @@ class XGBModel(XGBModelBase):
         enable_categorical: bool = False,
         eval_metric: Optional[Union[str, List[str], Callable]] = None,
         early_stopping_rounds: Optional[int] = None,
+        callbacks: Optional[List[TrainingCallback]] = None,
         **kwargs: Any
     ) -> None:
         if not SKLEARN_INSTALLED:
@@ -511,6 +522,7 @@ class XGBModel(XGBModelBase):
         self.enable_categorical = enable_categorical
         self.eval_metric = eval_metric
         self.early_stopping_rounds = early_stopping_rounds
+        self.callbacks = callbacks
         if kwargs:
             self.kwargs = kwargs
@@ -628,6 +640,7 @@ class XGBModel(XGBModelBase):
             "use_label_encoder",
             "enable_categorical",
             "early_stopping_rounds",
+            "callbacks",
         }
         filtered = {}
         for k, v in params.items():
@@ -719,11 +732,13 @@ class XGBModel(XGBModelBase):
         eval_metric: Optional[Union[Callable, str, Sequence[str]]],
         params: Dict[str, Any],
         early_stopping_rounds: Optional[int],
+        callbacks: Optional[Sequence[TrainingCallback]],
     ) -> Tuple[
         Optional[Union[Booster, str, "XGBModel"]],
         Optional[Metric],
         Dict[str, Any],
         Optional[int],
+        Optional[Sequence[TrainingCallback]],
     ]:
         """Configure parameters for :py:meth:`fit`."""
         if isinstance(booster, XGBModel):
@@ -779,13 +794,21 @@ class XGBModel(XGBModelBase):
             else early_stopping_rounds
         )
 
+        # Configure callbacks
+        if callbacks is not None:
+            _deprecated("callbacks")
+        if callbacks is not None and self.callbacks is not None:
+            _duplicated("callbacks")
+        callbacks = self.callbacks if self.callbacks is not None else callbacks
+
+        # lastly check categorical data support.
         if self.enable_categorical and params.get("tree_method", None) != "gpu_hist":
             raise ValueError(
                 "Experimental support for categorical data is not implemented for"
                 " current tree method yet."
             )
-        return model, metric, params, early_stopping_rounds
+        return model, metric, params, early_stopping_rounds, callbacks
 
     def _set_evaluation_result(self, evals_result: TrainingCallback.EvalsLog) -> None:
         if evals_result:
@@ -856,16 +879,10 @@ class XGBModel(XGBModelBase):
            selected when colsample is being used.  All values must be greater than 0,
            otherwise a `ValueError` is thrown.  Only available for `hist`, `gpu_hist` and
            `exact` tree methods.
+
        callbacks :
-            List of callback functions that are applied at end of each iteration.
-            It is possible to use predefined callbacks by using :ref:`callback_api`.
-            Example:
-
-            .. code-block:: python
-
-                callbacks = [xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
-                                                        save_best=True)]
-
+            .. deprecated:: 1.5.1
+                Use `callbacks` in :py:meth:`__init__` or :py:meth:`set_params` instead.
""" evals_result: TrainingCallback.EvalsLog = {} train_dmatrix, evals = _wrap_evaluation_matrices( @@ -895,8 +912,8 @@ class XGBModel(XGBModelBase): else: obj = None - model, metric, params, early_stopping_rounds = self._configure_fit( - xgb_model, eval_metric, params, early_stopping_rounds + model, metric, params, early_stopping_rounds, callbacks = self._configure_fit( + xgb_model, eval_metric, params, early_stopping_rounds, callbacks ) self._Booster = train( params, @@ -1290,8 +1307,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase): params["objective"] = "multi:softprob" params["num_class"] = self.n_classes_ - model, metric, params, early_stopping_rounds = self._configure_fit( - xgb_model, eval_metric, params, early_stopping_rounds + model, metric, params, early_stopping_rounds, callbacks = self._configure_fit( + xgb_model, eval_metric, params, early_stopping_rounds, callbacks ) train_dmatrix, evals = _wrap_evaluation_matrices( missing=self.missing, @@ -1453,7 +1470,7 @@ class XGBRFClassifier(XGBClassifier): colsample_bynode=colsample_bynode, reg_lambda=reg_lambda, **kwargs) - _check_rf_callback(self.early_stopping_rounds, None) + _check_rf_callback(self.early_stopping_rounds, self.callbacks) def get_xgb_params(self) -> Dict[str, Any]: params = super().get_xgb_params() @@ -1525,7 +1542,7 @@ class XGBRFRegressor(XGBRegressor): reg_lambda=reg_lambda, **kwargs ) - _check_rf_callback(self.early_stopping_rounds, None) + _check_rf_callback(self.early_stopping_rounds, self.callbacks) def get_xgb_params(self) -> Dict[str, Any]: params = super().get_xgb_params() @@ -1708,16 +1725,10 @@ class XGBRanker(XGBModel, XGBRankerMixIn): selected when colsample is being used. All values must be greater than 0, otherwise a `ValueError` is thrown. Only available for `hist`, `gpu_hist` and `exact` tree methods. + callbacks : - List of callback functions that are applied at end of each - iteration. It is possible to use predefined callbacks by using - :ref:`callback_api`. Example: - - .. code-block:: python - - callbacks = [xgb.callback.EarlyStopping(rounds=early_stopping_rounds, - save_best=True)] - + .. deprecated: 1.5.1 + Use `callbacks` in :py:meth:`__init__` or :py:methd:`set_params` instead. 
""" # check if group information is provided if group is None and qid is None: @@ -1748,8 +1759,8 @@ class XGBRanker(XGBModel, XGBRankerMixIn): evals_result: TrainingCallback.EvalsLog = {} params = self.get_xgb_params() - model, metric, params, early_stopping_rounds = self._configure_fit( - xgb_model, eval_metric, params, early_stopping_rounds + model, metric, params, early_stopping_rounds, callbacks = self._configure_fit( + xgb_model, eval_metric, params, early_stopping_rounds, callbacks ) if callable(metric): raise ValueError( @@ -1757,8 +1768,9 @@ class XGBRanker(XGBModel, XGBRankerMixIn): ) self._Booster = train( - params, train_dmatrix, - self.n_estimators, + params, + train_dmatrix, + self.get_num_boosting_rounds(), early_stopping_rounds=early_stopping_rounds, evals=evals, evals_result=evals_result, diff --git a/tests/python/test_callback.py b/tests/python/test_callback.py index 93a304e1c..08ba9ee79 100644 --- a/tests/python/test_callback.py +++ b/tests/python/test_callback.py @@ -185,10 +185,12 @@ class TestCallbacks: def test_early_stopping_custom_eval_skl(self): from sklearn.datasets import load_breast_cancer X, y = load_breast_cancer(return_X_y=True) - cls = xgb.XGBClassifier(eval_metric=tm.eval_error_metric_skl) early_stopping_rounds = 5 early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds) - cls.fit(X, y, eval_set=[(X, y)], callbacks=[early_stop]) + cls = xgb.XGBClassifier( + eval_metric=tm.eval_error_metric_skl, callbacks=[early_stop] + ) + cls.fit(X, y, eval_set=[(X, y)]) booster = cls.get_booster() dump = booster.get_dump(dump_format='json') assert len(dump) - booster.best_iteration == early_stopping_rounds + 1 @@ -197,13 +199,15 @@ class TestCallbacks: from sklearn.datasets import load_breast_cancer X, y = load_breast_cancer(return_X_y=True) n_estimators = 100 - cls = xgb.XGBClassifier( - n_estimators=n_estimators, eval_metric=tm.eval_error_metric_skl - ) early_stopping_rounds = 5 early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds, save_best=True) - cls.fit(X, y, eval_set=[(X, y)], callbacks=[early_stop]) + cls = xgb.XGBClassifier( + n_estimators=n_estimators, + eval_metric=tm.eval_error_metric_skl, + callbacks=[early_stop] + ) + cls.fit(X, y, eval_set=[(X, y)]) booster = cls.get_booster() dump = booster.get_dump(dump_format='json') assert len(dump) == booster.best_iteration + 1 @@ -228,9 +232,12 @@ class TestCallbacks: X, y = load_breast_cancer(return_X_y=True) cls = xgb.XGBClassifier(eval_metric=tm.eval_error_metric_skl) early_stopping_rounds = 5 - early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds, - save_best=True) - cls.fit(X, y, eval_set=[(X, y)], callbacks=[early_stop]) + early_stop = xgb.callback.EarlyStopping( + rounds=early_stopping_rounds, save_best=True + ) + with pytest.warns(UserWarning): + cls.fit(X, y, eval_set=[(X, y)], callbacks=[early_stop]) + booster = cls.get_booster() assert booster.num_boosted_rounds() == booster.best_iteration + 1 @@ -247,6 +254,19 @@ class TestCallbacks: assert booster.num_boosted_rounds() == \ booster.best_iteration + early_stopping_rounds + 1 + def test_deprecated(self): + from sklearn.datasets import load_breast_cancer + X, y = load_breast_cancer(return_X_y=True) + early_stopping_rounds = 5 + early_stop = xgb.callback.EarlyStopping( + rounds=early_stopping_rounds, save_best=True + ) + clf = xgb.XGBClassifier( + eval_metric=tm.eval_error_metric_skl, callbacks=[early_stop] + ) + with pytest.raises(ValueError, match=r".*set_params.*"): + clf.fit(X, y, eval_set=[(X, y)], 
callbacks=[early_stop]) + def run_eta_decay(self, tree_method): """Test learning rate scheduler, used by both CPU and GPU tests.""" scheduler = xgb.callback.LearningRateScheduler
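
The snippet below is a minimal usage sketch, not part of the patch, illustrating the constructor-level `callbacks` parameter this diff introduces. It mirrors the updated tests above; the breast-cancer dataset, the "logloss" metric, and the EarlyStopping settings are arbitrary choices for illustration.

import xgboost as xgb
from sklearn.datasets import load_breast_cancer

X, y = load_breast_cancer(return_X_y=True)

# With this change, callbacks are configured on the estimator itself.
# Passing callbacks= to fit() still works but emits a deprecation warning,
# and supplying callbacks in both places raises an error.
early_stop = xgb.callback.EarlyStopping(rounds=5, save_best=True)
clf = xgb.XGBClassifier(eval_metric="logloss", callbacks=[early_stop])
clf.fit(X, y, eval_set=[(X, y)])
print(clf.get_booster().best_iteration)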