From a5c852660b1056204aa2e0cbfcd5b4ecfbf31adf Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 1 Apr 2021 15:52:36 +0800 Subject: [PATCH] Update document for sklearn model IO. (#6809) * Update the use of JSON. * Remove unnecessary type cast. --- python-package/xgboost/core.py | 6 ++-- python-package/xgboost/sklearn.py | 49 +++++++------------------------ tests/python/test_with_sklearn.py | 6 +++- 3 files changed, 19 insertions(+), 42 deletions(-) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 6a9a1dd48..56a6e10b4 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -1949,7 +1949,7 @@ class Booster(object): ctypes.byref(cptr))) return ctypes2buffer(cptr, length.value) - def load_model(self, fname): + def load_model(self, fname: Union[str, bytearray, os.PathLike]) -> None: """Load the model from a file or bytearray. Path to file can be local or as an URI. @@ -1964,11 +1964,11 @@ class Booster(object): Parameters ---------- - fname : string, os.PathLike, or a memory buffer + fname : Input file name or memory buffer(see also save_raw) """ - if isinstance(fname, (STRING_TYPES, os.PathLike)): + if isinstance(fname, (str, os.PathLike)): # assume file name, cannot use os.path.exist to check, file can be # from URL. fname = os.fspath(os.path.expanduser(fname)) diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 10a115027..6aefc3d0c 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -4,6 +4,7 @@ import copy import warnings import json +import os from typing import Union, Optional, List, Dict, Callable, Tuple, Any import numpy as np from .core import Booster, DMatrix, XGBoostError @@ -500,25 +501,7 @@ class XGBModel(XGBModelBase): ) return self._estimator_type # pylint: disable=no-member - def save_model(self, fname: str): - """Save the model to a file. - - The model is saved in an XGBoost internal format which is universal - among the various XGBoost interfaces. Auxiliary attributes of the - Python Booster object (such as feature names) will not be saved. - - .. note:: - - See: - - https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html - - Parameters - ---------- - fname : string - Output file name - - """ + def save_model(self, fname: Union[str, os.PathLike]): meta = dict() for k, v in self.__dict__.items(): if k == '_le': @@ -542,27 +525,18 @@ class XGBModel(XGBModelBase): # Delete the attribute after save self.get_booster().set_attr(scikit_learn=None) - def load_model(self, fname): + save_model.__doc__ = f"""{Booster.save_model.__doc__}""" + + def load_model(self, fname: Union[str, bytearray, os.PathLike]) -> None: # pylint: disable=attribute-defined-outside-init - """Load the model from a file. - - The model is loaded from an XGBoost internal format which is universal - among the various XGBoost interfaces. Auxiliary attributes of the - Python Booster object (such as feature names) will not be loaded. - - Parameters - ---------- - fname : string - Input file name. - - """ if not hasattr(self, '_Booster'): self._Booster = Booster({'n_jobs': self.n_jobs}) - self._Booster.load_model(fname) - meta = self._Booster.attr('scikit_learn') + self.get_booster().load_model(fname) + meta = self.get_booster().attr('scikit_learn') if meta is None: warnings.warn( - 'Loading a native XGBoost model with Scikit-Learn interface.') + 'Loading a native XGBoost model with Scikit-Learn interface.' + ) return meta = json.loads(meta) states = dict() @@ -574,9 +548,6 @@ class XGBModel(XGBModelBase): if k == 'classes_': self.classes_ = np.array(v) continue - if k == 'use_label_encoder': - self.use_label_encoder = bool(v) - continue if k == "_estimator_type": if self._get_type() != v: raise TypeError( @@ -589,6 +560,8 @@ class XGBModel(XGBModelBase): # Delete the attribute after load self.get_booster().set_attr(scikit_learn=None) + load_model.__doc__ = f"""{Booster.load_model.__doc__}""" + def _configure_fit( self, booster: Optional[Union[Booster, "XGBModel"]], diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index c5081fd83..12297ceec 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -804,10 +804,14 @@ def save_load_model(model_path): for train_index, test_index in kf.split(X, y): xgb_model = xgb.XGBClassifier(use_label_encoder=False).fit(X[train_index], y[train_index]) xgb_model.save_model(model_path) - xgb_model = xgb.XGBClassifier(use_label_encoder=False) + + xgb_model = xgb.XGBClassifier() xgb_model.load_model(model_path) + + assert xgb_model.use_label_encoder is False assert isinstance(xgb_model.classes_, np.ndarray) assert isinstance(xgb_model._Booster, xgb.Booster) + preds = xgb_model.predict(X[test_index]) labels = y[test_index] err = sum(1 for i in range(len(preds))