Implement __sklearn_is_fitted__. (#7230)
This commit is contained in:
@@ -435,6 +435,9 @@ class XGBModel(XGBModelBase):
|
||||
'''Tags used for scikit-learn data validation.'''
|
||||
return {'allow_nan': True, 'no_validation': True}
|
||||
|
||||
def __sklearn_is_fitted__(self) -> bool:
|
||||
return hasattr(self, "_Booster")
|
||||
|
||||
def get_booster(self) -> Booster:
|
||||
"""Get the underlying xgboost Booster of this model.
|
||||
|
||||
@@ -444,7 +447,7 @@ class XGBModel(XGBModelBase):
|
||||
-------
|
||||
booster : a xgboost booster of underlying model
|
||||
"""
|
||||
if not hasattr(self, '_Booster'):
|
||||
if not self.__sklearn_is_fitted__():
|
||||
from sklearn.exceptions import NotFittedError
|
||||
raise NotFittedError('need to call fit or load_model beforehand')
|
||||
return self._Booster
|
||||
|
||||
Reference in New Issue
Block a user