Calling XGBModel.fit() should clear the Booster by default (#6562)
* Calling XGBModel.fit() should clear the Booster by default * Document the behavior of fit() * Allow sklearn object to be passed in directly via xgb_model argument * Fix lint
This commit is contained in:
parent
5e9e525223
commit
fa13992264
@ -501,8 +501,10 @@ class XGBModel(XGBModelBase):
|
||||
eval_metric: Optional[Union[Callable, str, List[str]]],
|
||||
params: Dict[str, Any],
|
||||
) -> Tuple[Booster, Optional[Union[Callable, str, List[str]]], Dict[str, Any]]:
|
||||
model = self._Booster if hasattr(self, "_Booster") else None
|
||||
model = booster if booster is not None else model
|
||||
# pylint: disable=protected-access, no-self-use
|
||||
model = booster
|
||||
if hasattr(model, '_Booster'):
|
||||
model = model._Booster # Handle the case when xgb_model is a sklearn model object
|
||||
feval = eval_metric if callable(eval_metric) else None
|
||||
if eval_metric is not None:
|
||||
if callable(eval_metric):
|
||||
@ -518,7 +520,11 @@ class XGBModel(XGBModelBase):
|
||||
feature_weights=None,
|
||||
callbacks=None):
|
||||
# pylint: disable=invalid-name,attribute-defined-outside-init
|
||||
"""Fit gradient boosting model
|
||||
"""Fit gradient boosting model.
|
||||
|
||||
Note that calling ``fit()`` multiple times will cause the model object to be re-fit from
|
||||
scratch. To resume training from a previous checkpoint, explicitly pass ``xgb_model``
|
||||
argument.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@ -1212,6 +1218,10 @@ class XGBRanker(XGBModel):
|
||||
# pylint: disable = attribute-defined-outside-init,arguments-differ
|
||||
"""Fit gradient boosting ranker
|
||||
|
||||
Note that calling ``fit()`` multiple times will cause the model object to be re-fit from
|
||||
scratch. To resume training from a previous checkpoint, explicitly pass ``xgb_model``
|
||||
argument.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : array_like
|
||||
@ -1322,6 +1332,9 @@ class XGBRanker(XGBModel):
|
||||
raise ValueError(
|
||||
'Custom evaluation metric is not yet supported for XGBRanker.')
|
||||
params.update({'eval_metric': eval_metric})
|
||||
if hasattr(xgb_model, '_Booster'):
|
||||
# Handle the case when xgb_model is a sklearn model object
|
||||
xgb_model = xgb_model._Booster # pylint: disable=protected-access
|
||||
|
||||
self._Booster = train(params, train_dmatrix,
|
||||
self.n_estimators,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user