Implement __sklearn_is_fitted__. (#7230)

This commit is contained in:
Jiaming Yuan 2021-09-15 19:09:04 +08:00 committed by GitHub
parent d997c967d5
commit 037dd0820d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 2 deletions

View File

@ -435,6 +435,9 @@ class XGBModel(XGBModelBase):
'''Tags used for scikit-learn data validation.'''
return {'allow_nan': True, 'no_validation': True}
def __sklearn_is_fitted__(self) -> bool:
return hasattr(self, "_Booster")
def get_booster(self) -> Booster:
"""Get the underlying xgboost Booster of this model.
@ -444,7 +447,7 @@ class XGBModel(XGBModelBase):
-------
booster : a xgboost booster of underlying model
"""
if not hasattr(self, '_Booster'):
if not self.__sklearn_is_fitted__():
from sklearn.exceptions import NotFittedError
raise NotFittedError('need to call fit or load_model beforehand')
return self._Booster

View File

@ -19,7 +19,7 @@ def test_gpu_binary_classification():
from sklearn.datasets import load_digits
from sklearn.model_selection import KFold
digits = load_digits(2)
digits = load_digits(n_class=2)
y = digits['target']
X = digits['data']
kf = KFold(n_splits=2, shuffle=True, random_state=rng)