Calling XGBModel.fit() should clear the Booster by default (#6562)
* Calling XGBModel.fit() should clear the Booster by default
* Document the behavior of fit()
* Allow sklearn object to be passed in directly via xgb_model argument
* Fix lint
parent 5e9e525223
commit fa13992264
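In effect, a repeated ``fit()`` call now discards the previously trained ``Booster`` instead of silently continuing from it, and continuation is requested explicitly through the ``xgb_model`` argument, which after this commit also accepts a fitted sklearn-style estimator. A minimal usage sketch (the synthetic data and variable names are illustrative, not part of the commit):

```python
import numpy as np
from xgboost import XGBClassifier

X = np.random.rand(100, 10)
y = np.random.randint(2, size=100)

clf = XGBClassifier(n_estimators=10)
clf.fit(X, y)  # trains a fresh Booster
clf.fit(X, y)  # re-fits from scratch; the first Booster is discarded

# Continuation is explicit, and the sklearn object itself can now be
# passed where previously a raw Booster was expected:
cont = XGBClassifier(n_estimators=10)
cont.fit(X, y, xgb_model=clf)
```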
@@ -501,8 +501,10 @@ class XGBModel(XGBModelBase):
         eval_metric: Optional[Union[Callable, str, List[str]]],
         params: Dict[str, Any],
     ) -> Tuple[Booster, Optional[Union[Callable, str, List[str]]], Dict[str, Any]]:
-        model = self._Booster if hasattr(self, "_Booster") else None
-        model = booster if booster is not None else model
+        # pylint: disable=protected-access, no-self-use
+        model = booster
+        if hasattr(model, '_Booster'):
+            model = model._Booster  # Handle the case when xgb_model is a sklearn model object
         feval = eval_metric if callable(eval_metric) else None
         if eval_metric is not None:
             if callable(eval_metric):
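The unwrapping introduced above is plain duck typing: any object carrying a fitted ``_Booster`` attribute is treated as a sklearn-style wrapper, and anything else (a raw ``Booster`` or ``None``) passes through unchanged. The same idea as a standalone sketch, with a hypothetical helper name:

```python
def _unwrap_booster(xgb_model):
    """Hypothetical helper mirroring the hasattr check above: return the
    underlying Booster whether given a raw Booster or a fitted wrapper."""
    if hasattr(xgb_model, '_Booster'):
        # sklearn-style estimators keep their trained model in `_Booster`
        return xgb_model._Booster  # pylint: disable=protected-access
    return xgb_model  # already a Booster, or None
```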
@@ -518,7 +520,11 @@ class XGBModel(XGBModelBase):
             feature_weights=None,
             callbacks=None):
         # pylint: disable=invalid-name,attribute-defined-outside-init
-        """Fit gradient boosting model
+        """Fit gradient boosting model.
+
+        Note that calling ``fit()`` multiple times will cause the model object to be re-fit from
+        scratch. To resume training from a previous checkpoint, explicitly pass the ``xgb_model``
+        argument.
 
         Parameters
         ----------
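One way to observe the documented behavior is to count trees in the underlying model after each call. Continuing the earlier sketch's ``X`` and ``y``, and assuming a binary objective so that one boosting round yields one tree (``get_dump()`` is used here only as a convenient tree counter):

```python
a = XGBClassifier(n_estimators=5).fit(X, y)
print(len(a.get_booster().get_dump()))  # 5 trees

a.fit(X, y)  # re-fit from scratch rather than appending rounds
print(len(a.get_booster().get_dump()))  # still 5, not 10

b = XGBClassifier(n_estimators=5).fit(X, y, xgb_model=a)
print(len(b.get_booster().get_dump()))  # 10: 5 resumed + 5 new
```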
@@ -1212,6 +1218,10 @@ class XGBRanker(XGBModel):
         # pylint: disable = attribute-defined-outside-init,arguments-differ
         """Fit gradient boosting ranker
 
+        Note that calling ``fit()`` multiple times will cause the model object to be re-fit from
+        scratch. To resume training from a previous checkpoint, explicitly pass the ``xgb_model``
+        argument.
+
         Parameters
         ----------
         X : array_like
@@ -1322,6 +1332,9 @@ class XGBRanker(XGBModel):
                 raise ValueError(
                     'Custom evaluation metric is not yet supported for XGBRanker.')
             params.update({'eval_metric': eval_metric})
+        if hasattr(xgb_model, '_Booster'):
+            # Handle the case when xgb_model is a sklearn model object
+            xgb_model = xgb_model._Booster  # pylint: disable=protected-access
 
         self._Booster = train(params, train_dmatrix,
                               self.n_estimators,
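The same pattern now applies to the ranker: ``XGBRanker.fit()`` unwraps a sklearn-style ``xgb_model`` before handing it to ``train()``. A hedged sketch reusing the earlier ``X`` (the relevance labels and query-group sizes are made up; group sizes must sum to the number of rows):

```python
from xgboost import XGBRanker

y_rank = np.random.randint(5, size=100)  # graded relevance labels

ranker = XGBRanker(n_estimators=4)
ranker.fit(X, y_rank, group=[50, 50])  # two query groups of 50 rows each

# Resume from the fitted ranker directly; no manual get_booster() needed:
ranker2 = XGBRanker(n_estimators=4)
ranker2.fit(X, y_rank, group=[50, 50], xgb_model=ranker)
```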