Implement __sklearn_is_fitted__. (#7230)
This commit is contained in:
parent
d997c967d5
commit
037dd0820d
@ -435,6 +435,9 @@ class XGBModel(XGBModelBase):
|
|||||||
'''Tags used for scikit-learn data validation.'''
|
'''Tags used for scikit-learn data validation.'''
|
||||||
return {'allow_nan': True, 'no_validation': True}
|
return {'allow_nan': True, 'no_validation': True}
|
||||||
|
|
||||||
|
def __sklearn_is_fitted__(self) -> bool:
|
||||||
|
return hasattr(self, "_Booster")
|
||||||
|
|
||||||
def get_booster(self) -> Booster:
|
def get_booster(self) -> Booster:
|
||||||
"""Get the underlying xgboost Booster of this model.
|
"""Get the underlying xgboost Booster of this model.
|
||||||
|
|
||||||
@ -444,7 +447,7 @@ class XGBModel(XGBModelBase):
|
|||||||
-------
|
-------
|
||||||
booster : a xgboost booster of underlying model
|
booster : a xgboost booster of underlying model
|
||||||
"""
|
"""
|
||||||
if not hasattr(self, '_Booster'):
|
if not self.__sklearn_is_fitted__():
|
||||||
from sklearn.exceptions import NotFittedError
|
from sklearn.exceptions import NotFittedError
|
||||||
raise NotFittedError('need to call fit or load_model beforehand')
|
raise NotFittedError('need to call fit or load_model beforehand')
|
||||||
return self._Booster
|
return self._Booster
|
||||||
|
|||||||
@ -19,7 +19,7 @@ def test_gpu_binary_classification():
|
|||||||
from sklearn.datasets import load_digits
|
from sklearn.datasets import load_digits
|
||||||
from sklearn.model_selection import KFold
|
from sklearn.model_selection import KFold
|
||||||
|
|
||||||
digits = load_digits(2)
|
digits = load_digits(n_class=2)
|
||||||
y = digits['target']
|
y = digits['target']
|
||||||
X = digits['data']
|
X = digits['data']
|
||||||
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
|
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user