Calling XGBModel.fit() should clear the Booster by default (#6562)

* Calling XGBModel.fit() should clear the Booster by default

* Document the behavior of fit()

* Allow sklearn object to be passed in directly via xgb_model argument

* Fix lint
This commit is contained in:
Philip Hyunsu Cho 2020-12-31 11:02:08 -08:00 committed by GitHub
parent 5e9e525223
commit fa13992264
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -501,8 +501,10 @@ class XGBModel(XGBModelBase):
eval_metric: Optional[Union[Callable, str, List[str]]], eval_metric: Optional[Union[Callable, str, List[str]]],
params: Dict[str, Any], params: Dict[str, Any],
) -> Tuple[Booster, Optional[Union[Callable, str, List[str]]], Dict[str, Any]]: ) -> Tuple[Booster, Optional[Union[Callable, str, List[str]]], Dict[str, Any]]:
model = self._Booster if hasattr(self, "_Booster") else None # pylint: disable=protected-access, no-self-use
model = booster if booster is not None else model model = booster
if hasattr(model, '_Booster'):
model = model._Booster # Handle the case when xgb_model is a sklearn model object
feval = eval_metric if callable(eval_metric) else None feval = eval_metric if callable(eval_metric) else None
if eval_metric is not None: if eval_metric is not None:
if callable(eval_metric): if callable(eval_metric):
@ -518,7 +520,11 @@ class XGBModel(XGBModelBase):
feature_weights=None, feature_weights=None,
callbacks=None): callbacks=None):
# pylint: disable=invalid-name,attribute-defined-outside-init # pylint: disable=invalid-name,attribute-defined-outside-init
"""Fit gradient boosting model """Fit gradient boosting model.
Note that calling ``fit()`` multiple times will cause the model object to be re-fit from
scratch. To resume training from a previous checkpoint, explicitly pass ``xgb_model``
argument.
Parameters Parameters
---------- ----------
@ -1212,6 +1218,10 @@ class XGBRanker(XGBModel):
# pylint: disable = attribute-defined-outside-init,arguments-differ # pylint: disable = attribute-defined-outside-init,arguments-differ
"""Fit gradient boosting ranker """Fit gradient boosting ranker
Note that calling ``fit()`` multiple times will cause the model object to be re-fit from
scratch. To resume training from a previous checkpoint, explicitly pass ``xgb_model``
argument.
Parameters Parameters
---------- ----------
X : array_like X : array_like
@ -1322,6 +1332,9 @@ class XGBRanker(XGBModel):
raise ValueError( raise ValueError(
'Custom evaluation metric is not yet supported for XGBRanker.') 'Custom evaluation metric is not yet supported for XGBRanker.')
params.update({'eval_metric': eval_metric}) params.update({'eval_metric': eval_metric})
if hasattr(xgb_model, '_Booster'):
# Handle the case when xgb_model is a sklearn model object
xgb_model = xgb_model._Booster # pylint: disable=protected-access
self._Booster = train(params, train_dmatrix, self._Booster = train(params, train_dmatrix,
self.n_estimators, self.n_estimators,