Implement __sklearn_is_fitted__. (#7230)
This commit is contained in:
parent
d997c967d5
commit
037dd0820d
@ -435,6 +435,9 @@ class XGBModel(XGBModelBase):
|
||||
'''Tags used for scikit-learn data validation.'''
|
||||
return {'allow_nan': True, 'no_validation': True}
|
||||
|
||||
def __sklearn_is_fitted__(self) -> bool:
|
||||
return hasattr(self, "_Booster")
|
||||
|
||||
def get_booster(self) -> Booster:
|
||||
"""Get the underlying xgboost Booster of this model.
|
||||
|
||||
@ -444,7 +447,7 @@ class XGBModel(XGBModelBase):
|
||||
-------
|
||||
booster : a xgboost booster of underlying model
|
||||
"""
|
||||
if not hasattr(self, '_Booster'):
|
||||
if not self.__sklearn_is_fitted__():
|
||||
from sklearn.exceptions import NotFittedError
|
||||
raise NotFittedError('need to call fit or load_model beforehand')
|
||||
return self._Booster
|
||||
|
||||
@ -19,7 +19,7 @@ def test_gpu_binary_classification():
|
||||
from sklearn.datasets import load_digits
|
||||
from sklearn.model_selection import KFold
|
||||
|
||||
digits = load_digits(2)
|
||||
digits = load_digits(n_class=2)
|
||||
y = digits['target']
|
||||
X = digits['data']
|
||||
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user