From c2b3a13e709c1f726cef1ebfd987a98809aafe48 Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Mon, 27 Mar 2023 21:34:10 +0800
Subject: [PATCH] [breaking][skl] Remove parameter serialization. (#8963)

- Remove parameter serialization in the scikit-learn interface. The
  scikit-learn interface's `save_model` saves only the model and discards all
  hyper-parameters. This aligns with the native XGBoost interface, which
  distinguishes between hyper-parameters and model parameters. With the
  scikit-learn interface, model parameters are attributes of the estimator:
  for instance, `n_features_in_` and `n_classes_` are always accessible as
  `estimator.n_features_in_` and `estimator.n_classes_`, but not through
  `estimator.get_params()`.
- Define a `load_model` method for the classifier to load its own attributes.
- Set `n_estimators` to `None` by default.
---
 python-package/xgboost/_typing.py        |   2 +
 python-package/xgboost/compat.py         |  25 ----
 python-package/xgboost/core.py           |   3 +-
 python-package/xgboost/dask.py           |   9 +-
 python-package/xgboost/libpath.py        |   2 +-
 python-package/xgboost/sklearn.py        | 183 ++++++++++-------
 python-package/xgboost/spark/core.py     |   2 +
 tests/python/test_model_compatibility.py |   1 -
 tests/python/test_with_sklearn.py        |  67 ++++++---
 9 files changed, 134 insertions(+), 160 deletions(-)

diff --git a/python-package/xgboost/_typing.py b/python-package/xgboost/_typing.py
index 0adad9478..774681031 100644
--- a/python-package/xgboost/_typing.py
+++ b/python-package/xgboost/_typing.py
@@ -43,6 +43,8 @@ FPreProcCallable = Callable
 # c_bst_ulong corresponds to bst_ulong defined in xgboost/c_api.h
 c_bst_ulong = ctypes.c_uint64  # pylint: disable=C0103
 
+ModelIn = Union[str, bytearray, os.PathLike]
+
 CTypeT = TypeVar(
     "CTypeT",
     ctypes.c_void_p,
diff --git a/python-package/xgboost/compat.py b/python-package/xgboost/compat.py
index 3be023abf..a01eeef09 100644
--- a/python-package/xgboost/compat.py
+++ b/python-package/xgboost/compat.py
@@ -88,31 +88,6 @@ def is_cudf_available() -> bool:
         return False
 
 
-class XGBoostLabelEncoder(LabelEncoder):
-    """Label encoder with JSON serialization methods."""
-
-    def to_json(self) -> Dict:
-        """Returns a JSON compatible dictionary"""
-        meta = {}
-        for k, v in self.__dict__.items():
-            if isinstance(v, np.ndarray):
-                meta[k] = v.tolist()
-            else:
-                meta[k] = v
-        return meta
-
-    def from_json(self, doc: Dict) -> None:
-        # pylint: disable=attribute-defined-outside-init
-        """Load the encoder back from a JSON compatible dict."""
-        meta = {}
-        for k, v in doc.items():
-            if k == "classes_":
-                self.classes_ = np.array(v)
-                continue
-            meta[k] = v
-        self.__dict__.update(meta)
-
-
 try:
     import scipy.sparse as scipy_sparse
     from scipy.sparse import csr_matrix as scipy_csr
diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index 5a0cfb3a2..a0393391e 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -47,6 +47,7 @@ from ._typing import (
     FeatureInfo,
     FeatureNames,
     FeatureTypes,
+    ModelIn,
     NumpyOrCupy,
     c_bst_ulong,
 )
@@ -2477,7 +2478,7 @@ class Booster:
         )
         return ctypes2buffer(cptr, length.value)
 
-    def load_model(self, fname: Union[str, bytearray, os.PathLike]) -> None:
+    def load_model(self, fname: ModelIn) -> None:
         """Load the model from a file or bytearray. Path to file can be local
         or as an URI.
diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 0e5e0d28e..a17fbad70 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -60,7 +60,7 @@ from typing import ( import numpy from . import collective, config -from ._typing import _T, FeatureNames, FeatureTypes +from ._typing import _T, FeatureNames, FeatureTypes, ModelIn from .callback import TrainingCallback from .compat import DataFrame, LazyLoader, concat, lazy_isinstance from .core import ( @@ -76,6 +76,7 @@ from .core import ( from .sklearn import ( XGBClassifier, XGBClassifierBase, + XGBClassifierMixIn, XGBModel, XGBRanker, XGBRankerMixIn, @@ -1839,7 +1840,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase): "Implementation of the scikit-learn API for XGBoost classification.", ["estimators", "model"], ) -class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase): +class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierMixIn, XGBClassifierBase): # pylint: disable=missing-class-docstring async def _fit_async( self, @@ -2019,6 +2020,10 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase): preds = da.map_blocks(_argmax, pred_probs, drop_axis=1) return preds + def load_model(self, fname: ModelIn) -> None: + super().load_model(fname) + self._load_model_attributes(self.get_booster()) + @xgboost_model_doc( """Implementation of the Scikit-Learn API for XGBoost Ranking. diff --git a/python-package/xgboost/libpath.py b/python-package/xgboost/libpath.py index 9223acaa5..be37b364e 100644 --- a/python-package/xgboost/libpath.py +++ b/python-package/xgboost/libpath.py @@ -55,7 +55,7 @@ def find_lib_path() -> List[str]: # XGBOOST_BUILD_DOC is defined by sphinx conf. if not lib_path and not os.environ.get("XGBOOST_BUILD_DOC", False): - link = "https://xgboost.readthedocs.io/en/latest/build.html" + link = "https://xgboost.readthedocs.io/en/stable/install.html" msg = ( "Cannot find XGBoost Library in the candidate path. " + "List of candidates:\n- " diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 52175981a..563ff8659 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -22,23 +22,18 @@ from typing import ( import numpy as np from scipy.special import softmax -from ._typing import ArrayLike, FeatureNames, FeatureTypes +from ._typing import ArrayLike, FeatureNames, FeatureTypes, ModelIn from .callback import TrainingCallback # Do not use class names on scikit-learn directly. Re-define the classes on # .compat to guarantee the behavior without scikit-learn -from .compat import ( - SKLEARN_INSTALLED, - XGBClassifierBase, - XGBModelBase, - XGBoostLabelEncoder, - XGBRegressorBase, -) +from .compat import SKLEARN_INSTALLED, XGBClassifierBase, XGBModelBase, XGBRegressorBase from .config import config_context from .core import ( Booster, DMatrix, Metric, + Objective, QuantileDMatrix, XGBoostError, _convert_ntree_limit, @@ -49,9 +44,24 @@ from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array, _is_pandas_df from .training import train +class XGBClassifierMixIn: # pylint: disable=too-few-public-methods + """MixIn for classification.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + def _load_model_attributes(self, booster: Booster) -> None: + config = json.loads(booster.save_config()) + self.n_classes_ = int(config["learner"]["learner_model_param"]["num_class"]) + # binary classification is treated as regression in XGBoost. 
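+        # The `num_class` entry in the booster config is 0 in that case, so any
+        # value below 2 is mapped back to 2 classes for the classifier.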
+ self.n_classes_ = 2 if self.n_classes_ < 2 else self.n_classes_ + + class XGBRankerMixIn: # pylint: disable=too-few-public-methods - """MixIn for ranking, defines the _estimator_type usually defined in scikit-learn base - classes.""" + """MixIn for ranking, defines the _estimator_type usually defined in scikit-learn + base classes. + + """ _estimator_type = "ranker" @@ -74,7 +84,7 @@ SklObjective = Optional[ def _objective_decorator( func: Callable[[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]] -) -> Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]]: +) -> Objective: """Decorate an objective function Converts an objective function using the typical sklearn metrics @@ -173,7 +183,7 @@ def ltr_metric_decorator(func: Callable, n_jobs: Optional[int]) -> Metric: __estimator_doc = """ - n_estimators : int + n_estimators : Optional[int] Number of gradient boosted trees. Equivalent to number of boosting rounds. """ @@ -598,6 +608,9 @@ def _wrap_evaluation_matrices( return train_dmatrix, evals +DEFAULT_N_ESTIMATORS = 100 + + @xgboost_model_doc( """Implementation of the Scikit-Learn API for XGBoost.""", ["estimators", "model", "objective"], @@ -611,7 +624,7 @@ class XGBModel(XGBModelBase): max_bin: Optional[int] = None, grow_policy: Optional[str] = None, learning_rate: Optional[float] = None, - n_estimators: int = 100, + n_estimators: Optional[int] = None, verbosity: Optional[int] = None, objective: SklObjective = None, booster: Optional[str] = None, @@ -797,7 +810,7 @@ class XGBModel(XGBModelBase): def get_num_boosting_rounds(self) -> int: """Gets the number of xgboost boosting rounds.""" - return self.n_estimators + return DEFAULT_N_ESTIMATORS if self.n_estimators is None else self.n_estimators def _get_type(self) -> str: if not hasattr(self, "_estimator_type"): @@ -809,72 +822,33 @@ class XGBModel(XGBModelBase): def save_model(self, fname: Union[str, os.PathLike]) -> None: meta: Dict[str, Any] = {} - for k, v in self.__dict__.items(): - if k == "_le": - meta["_le"] = self._le.to_json() - continue - if k == "_Booster": - continue - if k == "classes_": - # numpy array is not JSON serializable - meta["classes_"] = self.classes_.tolist() - continue - if k == "feature_types": - # Use the `feature_types` attribute from booster instead. - meta["feature_types"] = None - continue - try: - json.dumps({k: v}) - meta[k] = v - except TypeError: - warnings.warn( - str(k) + " is not saved in Scikit-Learn meta.", UserWarning - ) + # For validation. meta["_estimator_type"] = self._get_type() meta_str = json.dumps(meta) self.get_booster().set_attr(scikit_learn=meta_str) self.get_booster().save_model(fname) - # Delete the attribute after save self.get_booster().set_attr(scikit_learn=None) save_model.__doc__ = f"""{Booster.save_model.__doc__}""" - def load_model(self, fname: Union[str, bytearray, os.PathLike]) -> None: + def load_model(self, fname: ModelIn) -> None: # pylint: disable=attribute-defined-outside-init - if not hasattr(self, "_Booster"): + if not self.__sklearn_is_fitted__(): self._Booster = Booster({"n_jobs": self.n_jobs}) self.get_booster().load_model(fname) + meta_str = self.get_booster().attr("scikit_learn") if meta_str is None: - # FIXME(jiaming): This doesn't have to be a problem as most of the needed - # information like num_class and objective is in Learner class. 
- warnings.warn("Loading a native XGBoost model with Scikit-Learn interface.") return + meta = json.loads(meta_str) - states = {} - for k, v in meta.items(): - if k == "_le": - self._le = XGBoostLabelEncoder() - self._le.from_json(v) - continue - # FIXME(jiaming): This can be removed once label encoder is gone since we can - # generate it from `np.arange(self.n_classes_)` - if k == "classes_": - self.classes_ = np.array(v) - continue - if k == "feature_types": - self.feature_types = self.get_booster().feature_types - continue - if k == "_estimator_type": - if self._get_type() != v: - raise TypeError( - "Loading an estimator with different type. " - f"Expecting: {self._get_type()}, got: {v}" - ) - continue - states[k] = v - self.__dict__.update(states) - # Delete the attribute after load + t = meta.get("_estimator_type", None) + if t is not None and t != self._get_type(): + raise TypeError( + "Loading an estimator with different type. Expecting: " + f"{self._get_type()}, got: {t}" + ) + self.feature_types = self.get_booster().feature_types self.get_booster().set_attr(scikit_learn=None) load_model.__doc__ = f"""{Booster.load_model.__doc__}""" @@ -965,7 +939,6 @@ class XGBModel(XGBModelBase): "Experimental support for categorical data is not implemented for" " current tree method yet." ) - return model, metric, params, early_stopping_rounds, callbacks def _create_dmatrix(self, ref: Optional[DMatrix], **kwargs: Any) -> DMatrix: @@ -1086,9 +1059,7 @@ class XGBModel(XGBModelBase): params = self.get_xgb_params() if callable(self.objective): - obj: Optional[ - Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]] - ] = _objective_decorator(self.objective) + obj: Optional[Objective] = _objective_decorator(self.objective) params["objective"] = "reg:squarederror" else: obj = None @@ -1304,8 +1275,10 @@ class XGBModel(XGBModelBase): @property def feature_names_in_(self) -> np.ndarray: - """Names of features seen during :py:meth:`fit`. Defined only when `X` has feature - names that are all strings.""" + """Names of features seen during :py:meth:`fit`. Defined only when `X` has + feature names that are all strings. + + """ feature_names = self.get_booster().feature_names if feature_names is None: raise AttributeError( @@ -1453,26 +1426,19 @@ def _cls_predict_proba(n_classes: int, prediction: PredtT, vstack: Callable) -> "Implementation of the scikit-learn API for XGBoost classification.", ["model", "objective"], extra_parameters=""" - n_estimators : int + n_estimators : Optional[int] Number of boosting rounds. 
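+        When ``None``, 100 boosting rounds (``DEFAULT_N_ESTIMATORS``) are used.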
""", ) -class XGBClassifier(XGBModel, XGBClassifierBase): +class XGBClassifier(XGBModel, XGBClassifierMixIn, XGBClassifierBase): # pylint: disable=missing-docstring,invalid-name,too-many-instance-attributes @_deprecate_positional_args def __init__( self, *, objective: SklObjective = "binary:logistic", - use_label_encoder: Optional[bool] = None, **kwargs: Any, ) -> None: - # must match the parameters for `get_params` - self.use_label_encoder = use_label_encoder - if use_label_encoder is True: - raise ValueError("Label encoder was removed in 1.6.0.") - if use_label_encoder is not None: - warnings.warn("`use_label_encoder` is deprecated in 1.7.0.") super().__init__(objective=objective, **kwargs) @_deprecate_positional_args @@ -1496,38 +1462,38 @@ class XGBClassifier(XGBModel, XGBClassifierBase): # pylint: disable = attribute-defined-outside-init,too-many-statements with config_context(verbosity=self.verbosity): evals_result: TrainingCallback.EvalsLog = {} - + # We keep the n_classes_ as a simple member instead of loading it from + # booster in a Python property. This way we can have efficient and + # thread-safe prediction. if _is_cudf_df(y) or _is_cudf_ser(y): import cupy as cp # pylint: disable=E0401 - self.classes_ = cp.unique(y.values) - self.n_classes_ = len(self.classes_) - expected_classes = cp.arange(self.n_classes_) + classes = cp.unique(y.values) + self.n_classes_ = len(classes) + expected_classes = cp.array(self.classes_) elif _is_cupy_array(y): import cupy as cp # pylint: disable=E0401 - self.classes_ = cp.unique(y) - self.n_classes_ = len(self.classes_) - expected_classes = cp.arange(self.n_classes_) + classes = cp.unique(y) + self.n_classes_ = len(classes) + expected_classes = cp.array(self.classes_) else: - self.classes_ = np.unique(np.asarray(y)) - self.n_classes_ = len(self.classes_) - expected_classes = np.arange(self.n_classes_) + classes = np.unique(np.asarray(y)) + self.n_classes_ = len(classes) + expected_classes = self.classes_ if ( - self.classes_.shape != expected_classes.shape - or not (self.classes_ == expected_classes).all() + classes.shape != expected_classes.shape + or not (classes == expected_classes).all() ): raise ValueError( f"Invalid classes inferred from unique values of `y`. " - f"Expected: {expected_classes}, got {self.classes_}" + f"Expected: {expected_classes}, got {classes}" ) params = self.get_xgb_params() if callable(self.objective): - obj: Optional[ - Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]] - ] = _objective_decorator(self.objective) + obj: Optional[Objective] = _objective_decorator(self.objective) # Use default value. Is it really not used ? 
params["objective"] = "binary:logistic" else: @@ -1616,7 +1582,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase): if len(class_probs.shape) > 1 and self.n_classes_ != 2: # multi-class, turns softprob into softmax - column_indexes: np.ndarray = np.argmax(class_probs, axis=1) # type: ignore + column_indexes: np.ndarray = np.argmax(class_probs, axis=1) elif len(class_probs.shape) > 1 and class_probs.shape[1] != 1: # multi-label column_indexes = np.zeros(class_probs.shape) @@ -1628,8 +1594,6 @@ class XGBClassifier(XGBModel, XGBClassifierBase): column_indexes = np.repeat(0, class_probs.shape[0]) column_indexes[class_probs > 0.5] = 1 - if hasattr(self, "_le"): - return self._le.inverse_transform(column_indexes) return column_indexes def predict_proba( @@ -1693,17 +1657,22 @@ class XGBClassifier(XGBModel, XGBClassifierBase): base_margin=base_margin, iteration_range=iteration_range, ) - # If model is loaded from a raw booster there's no `n_classes_` - return _cls_predict_proba( - getattr(self, "n_classes_", 0), class_probs, np.vstack - ) + return _cls_predict_proba(self.n_classes_, class_probs, np.vstack) + + @property + def classes_(self) -> np.ndarray: + return np.arange(self.n_classes_) + + def load_model(self, fname: ModelIn) -> None: + super().load_model(fname) + self._load_model_attributes(self.get_booster()) @xgboost_model_doc( "scikit-learn API for XGBoost random forest classification.", ["model", "objective"], extra_parameters=""" - n_estimators : int + n_estimators : Optional[int] Number of trees in random forest to fit. """, ) @@ -1730,7 +1699,7 @@ class XGBRFClassifier(XGBClassifier): def get_xgb_params(self) -> Dict[str, Any]: params = super().get_xgb_params() - params["num_parallel_tree"] = self.n_estimators + params["num_parallel_tree"] = super().get_num_boosting_rounds() return params def get_num_boosting_rounds(self) -> int: @@ -1778,7 +1747,7 @@ class XGBRegressor(XGBModel, XGBRegressorBase): "scikit-learn API for XGBoost random forest regression.", ["model", "objective"], extra_parameters=""" - n_estimators : int + n_estimators : Optional[int] Number of trees in random forest to fit. 
""", ) @@ -1805,7 +1774,7 @@ class XGBRFRegressor(XGBRegressor): def get_xgb_params(self) -> Dict[str, Any]: params = super().get_xgb_params() - params["num_parallel_tree"] = self.n_estimators + params["num_parallel_tree"] = super().get_num_boosting_rounds() return params def get_num_boosting_rounds(self) -> int: diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 745c9348f..1a614f51f 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -39,6 +39,7 @@ import xgboost from xgboost import XGBClassifier, XGBRanker, XGBRegressor from xgboost.compat import is_cudf_available from xgboost.core import Booster +from xgboost.sklearn import DEFAULT_N_ESTIMATORS from xgboost.training import train as worker_train from .data import ( @@ -215,6 +216,7 @@ class _SparkXGBParams( filtered_params_dict = { k: params_dict[k] for k in params_dict if k not in _unsupported_xgb_params } + filtered_params_dict["n_estimators"] = DEFAULT_N_ESTIMATORS return filtered_params_dict def _set_xgb_params_default(self): diff --git a/tests/python/test_model_compatibility.py b/tests/python/test_model_compatibility.py index a46715e42..c9b7646ef 100644 --- a/tests/python/test_model_compatibility.py +++ b/tests/python/test_model_compatibility.py @@ -66,7 +66,6 @@ def run_scikit_model_check(name, path): cls.load_model(path) if name.find('0.90') == -1: assert len(cls.classes_) == gm.kClasses - assert len(cls._le.classes_) == gm.kClasses assert cls.n_classes_ == gm.kClasses assert (len(cls.get_booster().get_dump()) == gm.kRounds * gm.kForests * gm.kClasses), path diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index c34b7d2d1..90d4dff18 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -38,36 +38,34 @@ def test_binary_classification(): assert err < 0.1 -@pytest.mark.parametrize('objective', ['multi:softmax', 'multi:softprob']) +@pytest.mark.parametrize("objective", ["multi:softmax", "multi:softprob"]) def test_multiclass_classification(objective): from sklearn.datasets import load_iris from sklearn.model_selection import KFold def check_pred(preds, labels, output_margin): if output_margin: - err = sum(1 for i in range(len(preds)) - if preds[i].argmax() != labels[i]) / float(len(preds)) + err = sum( + 1 for i in range(len(preds)) if preds[i].argmax() != labels[i] + ) / float(len(preds)) else: - err = sum(1 for i in range(len(preds)) - if preds[i] != labels[i]) / float(len(preds)) + err = sum(1 for i in range(len(preds)) if preds[i] != labels[i]) / float( + len(preds) + ) assert err < 0.4 - iris = load_iris() - y = iris['target'] - X = iris['data'] + X, y = load_iris(return_X_y=True) kf = KFold(n_splits=2, shuffle=True, random_state=rng) for train_index, test_index in kf.split(X, y): - xgb_model = xgb.XGBClassifier(objective=objective).fit(X[train_index], y[train_index]) - assert (xgb_model.get_booster().num_boosted_rounds() == - xgb_model.n_estimators) + xgb_model = xgb.XGBClassifier(objective=objective).fit( + X[train_index], y[train_index] + ) + assert xgb_model.get_booster().num_boosted_rounds() == 100 preds = xgb_model.predict(X[test_index]) # test other params in XGBClassifier().fit - preds2 = xgb_model.predict(X[test_index], output_margin=True, - ntree_limit=3) - preds3 = xgb_model.predict(X[test_index], output_margin=True, - ntree_limit=0) - preds4 = xgb_model.predict(X[test_index], output_margin=False, - ntree_limit=3) + preds2 = xgb_model.predict(X[test_index], 
output_margin=True, ntree_limit=3) + preds3 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=0) + preds4 = xgb_model.predict(X[test_index], output_margin=False, ntree_limit=3) labels = y[test_index] check_pred(preds, labels, output_margin=False) @@ -761,9 +759,9 @@ def test_parameters_access(): clf = save_load(clf) assert clf.tree_method is None - assert clf.n_estimators == 2 + assert clf.n_estimators is None assert clf.get_params()["tree_method"] is None - assert clf.get_params()["n_estimators"] == 2 + assert clf.get_params()["n_estimators"] is None assert get_tm(clf) == "auto" # discarded for save/load_model clf.set_params(tree_method="hist") @@ -771,9 +769,7 @@ def test_parameters_access(): clf = pickle.loads(pickle.dumps(clf)) assert clf.get_params()["tree_method"] == "hist" clf = save_load(clf) - # FIXME(jiamingy): We should remove this behavior once we remove parameters - # serialization for skl save/load_model. - assert clf.get_params()["tree_method"] == "hist" + assert clf.get_params()["tree_method"] is None def test_kwargs_error(): @@ -902,6 +898,7 @@ def save_load_model(model_path): xgb_model.load_model(model_path) assert isinstance(xgb_model.classes_, np.ndarray) + np.testing.assert_equal(xgb_model.classes_, np.array([0, 1])) assert isinstance(xgb_model._Booster, xgb.Booster) preds = xgb_model.predict(X[test_index]) @@ -933,8 +930,10 @@ def test_save_load_model(): save_load_model(model_path) from sklearn.datasets import load_digits + from sklearn.model_selection import train_test_split + with tempfile.TemporaryDirectory() as tempdir: - model_path = os.path.join(tempdir, 'digits.model.json') + model_path = os.path.join(tempdir, 'digits.model.ubj') digits = load_digits(n_class=2) y = digits['target'] X = digits['data'] @@ -959,6 +958,28 @@ def test_save_load_model(): predt_1 = cls.predict(X) assert np.allclose(predt_0, predt_1) + # mclass + X, y = load_digits(n_class=10, return_X_y=True) + # small test_size to force early stop + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.01, random_state=1 + ) + clf = xgb.XGBClassifier( + n_estimators=64, tree_method="hist", early_stopping_rounds=2 + ) + clf.fit(X_train, y_train, eval_set=[(X_test, y_test)]) + score = clf.best_score + clf.save_model(model_path) + + clf = xgb.XGBClassifier() + clf.load_model(model_path) + assert clf.classes_.size == 10 + np.testing.assert_equal(clf.classes_, np.arange(10)) + assert clf.n_classes_ == 10 + + assert clf.best_iteration == 27 + assert clf.best_score == score + def test_RFECV(): from sklearn.datasets import load_breast_cancer, load_diabetes, load_iris
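For illustration, a minimal sketch of the save/load behavior this patch establishes, following the tests above (the file name is illustrative):

    import numpy as np
    import xgboost as xgb
    from sklearn.datasets import load_digits

    X, y = load_digits(n_class=10, return_X_y=True)
    clf = xgb.XGBClassifier(n_estimators=64, tree_method="hist")
    clf.fit(X, y)
    # Saves the booster only; hyper-parameters are discarded.
    clf.save_model("clf.json")

    loaded = xgb.XGBClassifier()
    loaded.load_model("clf.json")
    # Model parameters are recovered from the booster configuration.
    assert loaded.n_classes_ == 10
    np.testing.assert_equal(loaded.classes_, np.arange(10))
    # Hyper-parameters are not round-tripped; they keep constructor defaults.
    assert loaded.get_params()["tree_method"] is None
    assert loaded.get_params()["n_estimators"] is None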