Update document for sklearn model IO. (#6809)

* Update the use of JSON.
* Remove unnecessary type cast.
This commit is contained in:
Jiaming Yuan 2021-04-01 15:52:36 +08:00 committed by GitHub
parent 905fdd3e08
commit a5c852660b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 19 additions and 42 deletions

View File

@ -1949,7 +1949,7 @@ class Booster(object):
ctypes.byref(cptr)))
return ctypes2buffer(cptr, length.value)
def load_model(self, fname):
def load_model(self, fname: Union[str, bytearray, os.PathLike]) -> None:
"""Load the model from a file or bytearray. Path to file can be local
or as an URI.
@ -1964,11 +1964,11 @@ class Booster(object):
Parameters
----------
fname : string, os.PathLike, or a memory buffer
fname :
Input file name or memory buffer(see also save_raw)
"""
if isinstance(fname, (STRING_TYPES, os.PathLike)):
if isinstance(fname, (str, os.PathLike)):
# assume file name, cannot use os.path.exist to check, file can be
# from URL.
fname = os.fspath(os.path.expanduser(fname))

View File

@ -4,6 +4,7 @@
import copy
import warnings
import json
import os
from typing import Union, Optional, List, Dict, Callable, Tuple, Any
import numpy as np
from .core import Booster, DMatrix, XGBoostError
@ -500,25 +501,7 @@ class XGBModel(XGBModelBase):
)
return self._estimator_type # pylint: disable=no-member
def save_model(self, fname: str):
"""Save the model to a file.
The model is saved in an XGBoost internal format which is universal
among the various XGBoost interfaces. Auxiliary attributes of the
Python Booster object (such as feature names) will not be saved.
.. note::
See:
https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html
Parameters
----------
fname : string
Output file name
"""
def save_model(self, fname: Union[str, os.PathLike]):
meta = dict()
for k, v in self.__dict__.items():
if k == '_le':
@ -542,27 +525,18 @@ class XGBModel(XGBModelBase):
# Delete the attribute after save
self.get_booster().set_attr(scikit_learn=None)
def load_model(self, fname):
save_model.__doc__ = f"""{Booster.save_model.__doc__}"""
def load_model(self, fname: Union[str, bytearray, os.PathLike]) -> None:
# pylint: disable=attribute-defined-outside-init
"""Load the model from a file.
The model is loaded from an XGBoost internal format which is universal
among the various XGBoost interfaces. Auxiliary attributes of the
Python Booster object (such as feature names) will not be loaded.
Parameters
----------
fname : string
Input file name.
"""
if not hasattr(self, '_Booster'):
self._Booster = Booster({'n_jobs': self.n_jobs})
self._Booster.load_model(fname)
meta = self._Booster.attr('scikit_learn')
self.get_booster().load_model(fname)
meta = self.get_booster().attr('scikit_learn')
if meta is None:
warnings.warn(
'Loading a native XGBoost model with Scikit-Learn interface.')
'Loading a native XGBoost model with Scikit-Learn interface.'
)
return
meta = json.loads(meta)
states = dict()
@ -574,9 +548,6 @@ class XGBModel(XGBModelBase):
if k == 'classes_':
self.classes_ = np.array(v)
continue
if k == 'use_label_encoder':
self.use_label_encoder = bool(v)
continue
if k == "_estimator_type":
if self._get_type() != v:
raise TypeError(
@ -589,6 +560,8 @@ class XGBModel(XGBModelBase):
# Delete the attribute after load
self.get_booster().set_attr(scikit_learn=None)
load_model.__doc__ = f"""{Booster.load_model.__doc__}"""
def _configure_fit(
self,
booster: Optional[Union[Booster, "XGBModel"]],

View File

@ -804,10 +804,14 @@ def save_load_model(model_path):
for train_index, test_index in kf.split(X, y):
xgb_model = xgb.XGBClassifier(use_label_encoder=False).fit(X[train_index], y[train_index])
xgb_model.save_model(model_path)
xgb_model = xgb.XGBClassifier(use_label_encoder=False)
xgb_model = xgb.XGBClassifier()
xgb_model.load_model(model_path)
assert xgb_model.use_label_encoder is False
assert isinstance(xgb_model.classes_, np.ndarray)
assert isinstance(xgb_model._Booster, xgb.Booster)
preds = xgb_model.predict(X[test_index])
labels = y[test_index]
err = sum(1 for i in range(len(preds))