Implement __sklearn_is_fitted__. (#7230)

This commit is contained in:
Jiaming Yuan
2021-09-15 19:09:04 +08:00
committed by GitHub
parent d997c967d5
commit 037dd0820d
2 changed files with 5 additions and 2 deletions

View File

@@ -435,6 +435,9 @@ class XGBModel(XGBModelBase):
'''Tags used for scikit-learn data validation.'''
return {'allow_nan': True, 'no_validation': True}
def __sklearn_is_fitted__(self) -> bool:
return hasattr(self, "_Booster")
def get_booster(self) -> Booster:
"""Get the underlying xgboost Booster of this model.
@@ -444,7 +447,7 @@ class XGBModel(XGBModelBase):
-------
booster : a xgboost booster of underlying model
"""
if not hasattr(self, '_Booster'):
if not self.__sklearn_is_fitted__():
from sklearn.exceptions import NotFittedError
raise NotFittedError('need to call fit or load_model beforehand')
return self._Booster