diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 5b0fc9a30..9557fdd64 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -572,6 +572,13 @@ class XGBModel(XGBModelBase): params.update({"eval_metric": eval_metric}) return model, feval, params + def _set_evaluation_result(self, evals_result: Optional[dict]) -> None: + if evals_result: + for val in evals_result.items(): + evals_result_key = list(val[1].keys())[0] + evals_result[val[0]][evals_result_key] = val[1][evals_result_key] + self.evals_result_ = evals_result + @_deprecate_positional_args def fit(self, X, y, *, sample_weight=None, base_margin=None, eval_set=None, eval_metric=None, early_stopping_rounds=None, @@ -678,13 +685,7 @@ class XGBModel(XGBModelBase): verbose_eval=verbose, xgb_model=model, callbacks=callbacks) - if evals_result: - for val in evals_result.items(): - evals_result_key = list(val[1].keys())[0] - evals_result[val[0]][evals_result_key] = val[1][ - evals_result_key] - self.evals_result_ = evals_result - + self._set_evaluation_result(evals_result) return self def predict(self, data, output_margin=False, ntree_limit=None, @@ -1035,13 +1036,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase): if not callable(self.objective): self.objective = params["objective"] - if evals_result: - for val in evals_result.items(): - evals_result_key = list(val[1].keys())[0] - evals_result[val[0]][ - evals_result_key] = val[1][evals_result_key] - self.evals_result_ = evals_result - + self._set_evaluation_result(evals_result) return self fit.__doc__ = XGBModel.fit.__doc__.replace( @@ -1502,13 +1497,7 @@ class XGBRanker(XGBModel, XGBRankerMixIn): callbacks=callbacks) self.objective = params["objective"] - - if evals_result: - for val in evals_result.items(): - evals_result_key = list(val[1].keys())[0] - evals_result[val[0]][evals_result_key] = val[1][evals_result_key] - self.evals_result = evals_result - + self._set_evaluation_result(evals_result) return self def predict(self, data, output_margin=False, diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index 2c40e6c3d..5e2adc7c1 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -138,6 +138,7 @@ def test_ranking(): model = xgb.sklearn.XGBRanker(**params) model.fit(x_train, y_train, group=train_group, eval_set=[(x_valid, y_valid)], eval_group=[valid_group]) + assert model.evals_result() pred = model.predict(x_test)