Update document for sklearn model IO. (#6809)
* Update the use of JSON.
* Remove unnecessary type cast.

Parent: 905fdd3e08
Commit: a5c852660b

python-package/xgboost/core.py

@@ -1949,7 +1949,7 @@ class Booster(object):
             ctypes.byref(cptr)))
         return ctypes2buffer(cptr, length.value)
 
-    def load_model(self, fname):
+    def load_model(self, fname: Union[str, bytearray, os.PathLike]) -> None:
         """Load the model from a file or bytearray. Path to file can be local
         or as an URI.
 
@@ -1964,11 +1964,11 @@ class Booster(object):
 
         Parameters
         ----------
-        fname : string, os.PathLike, or a memory buffer
+        fname :
             Input file name or memory buffer(see also save_raw)
 
         """
-        if isinstance(fname, (STRING_TYPES, os.PathLike)):
+        if isinstance(fname, (str, os.PathLike)):
             # assume file name, cannot use os.path.exist to check, file can be
             # from URL.
             fname = os.fspath(os.path.expanduser(fname))
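The annotated signature spells out the three input kinds the body already handles: a string path, an os.PathLike object, and an in-memory buffer as produced by save_raw. A minimal sketch exercising all three; the toy data and file name are illustrative only, not part of this change:

```python
import pathlib
import numpy as np
import xgboost as xgb

# Train a tiny booster purely for illustration.
X, y = np.random.rand(20, 4), np.random.randint(2, size=20)
bst = xgb.train({"objective": "binary:logistic"}, xgb.DMatrix(X, label=y),
                num_boost_round=2)
bst.save_model("model.json")

loaded = xgb.Booster()
loaded.load_model("model.json")                # plain string path
loaded.load_model(pathlib.Path("model.json"))  # os.PathLike
loaded.load_model(bst.save_raw())              # in-memory buffer from save_raw()
```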
python-package/xgboost/sklearn.py

@@ -4,6 +4,7 @@
 import copy
 import warnings
 import json
+import os
 from typing import Union, Optional, List, Dict, Callable, Tuple, Any
 import numpy as np
 from .core import Booster, DMatrix, XGBoostError
@@ -500,25 +501,7 @@ class XGBModel(XGBModelBase):
         )
         return self._estimator_type  # pylint: disable=no-member
 
-    def save_model(self, fname: str):
-        """Save the model to a file.
-
-        The model is saved in an XGBoost internal format which is universal
-        among the various XGBoost interfaces. Auxiliary attributes of the
-        Python Booster object (such as feature names) will not be saved.
-
-        .. note::
-
-          See:
-
-          https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html
-
-        Parameters
-        ----------
-        fname : string
-            Output file name
-
-        """
+    def save_model(self, fname: Union[str, os.PathLike]):
         meta = dict()
         for k, v in self.__dict__.items():
             if k == '_le':
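For context, the surrounding save_model body (only partially visible in this hunk) collects the wrapper's attributes into meta, JSON-encodes them, attaches them to the booster as the scikit_learn attribute, and clears the attribute again after saving. A standalone sketch of that attribute mechanism, with made-up field values:

```python
import json
import xgboost as xgb

# Made-up stand-in for the wrapper metadata; the real fields come from self.__dict__.
meta = {"n_estimators": 100, "_estimator_type": "classifier"}

bst = xgb.Booster()
bst.set_attr(scikit_learn=json.dumps(meta))      # booster attributes are plain strings
restored = json.loads(bst.attr("scikit_learn"))  # what load_model reads back
bst.set_attr(scikit_learn=None)                  # "Delete the attribute after save"
```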
@@ -542,27 +525,18 @@ class XGBModel(XGBModelBase):
         # Delete the attribute after save
         self.get_booster().set_attr(scikit_learn=None)
 
-    def load_model(self, fname):
+    save_model.__doc__ = f"""{Booster.save_model.__doc__}"""
+
+    def load_model(self, fname: Union[str, bytearray, os.PathLike]) -> None:
         # pylint: disable=attribute-defined-outside-init
-        """Load the model from a file.
-
-        The model is loaded from an XGBoost internal format which is universal
-        among the various XGBoost interfaces. Auxiliary attributes of the
-        Python Booster object (such as feature names) will not be loaded.
-
-        Parameters
-        ----------
-        fname : string
-            Input file name.
-
-        """
         if not hasattr(self, '_Booster'):
             self._Booster = Booster({'n_jobs': self.n_jobs})
-        self._Booster.load_model(fname)
-        meta = self._Booster.attr('scikit_learn')
+        self.get_booster().load_model(fname)
+        meta = self.get_booster().attr('scikit_learn')
         if meta is None:
             warnings.warn(
-                'Loading a native XGBoost model with Scikit-Learn interface.')
+                'Loading a native XGBoost model with Scikit-Learn interface.'
+            )
             return
         meta = json.loads(meta)
         states = dict()
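When the scikit_learn attribute is absent, typically because the file was written by the native Booster rather than by the wrapper, load_model now warns and returns early instead of failing. A sketch of hitting that branch; the file name and data are illustrative:

```python
import warnings
import numpy as np
import xgboost as xgb

X, y = np.random.rand(16, 3), np.random.randint(2, size=16)
bst = xgb.train({"objective": "binary:logistic"}, xgb.DMatrix(X, label=y),
                num_boost_round=2)
bst.save_model("native.json")  # no scikit-learn metadata attached

clf = xgb.XGBClassifier()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    clf.load_model("native.json")  # takes the `meta is None` path above
assert any("Scikit-Learn interface" in str(w.message) for w in caught)
```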
@@ -574,9 +548,6 @@ class XGBModel(XGBModelBase):
             if k == 'classes_':
                 self.classes_ = np.array(v)
                 continue
-            if k == 'use_label_encoder':
-                self.use_label_encoder = bool(v)
-                continue
             if k == "_estimator_type":
                 if self._get_type() != v:
                     raise TypeError(
@@ -589,6 +560,8 @@ class XGBModel(XGBModelBase):
         # Delete the attribute after load
         self.get_booster().set_attr(scikit_learn=None)
 
+    load_model.__doc__ = f"""{Booster.load_model.__doc__}"""
+
     def _configure_fit(
         self,
         booster: Optional[Union[Booster, "XGBModel"]],
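Both wrappers now reuse the Booster docstrings through a plain __doc__ assignment inside the class body, so the save/load documentation is maintained in one place. The pattern in isolation, with toy classes standing in for Booster and XGBModel:

```python
class Backend:
    def load_model(self, fname):
        """Load the model from a file or bytearray."""


class Wrapper:
    def load_model(self, fname):
        self._impl = Backend()
        self._impl.load_model(fname)

    # Reuse the backend docstring instead of maintaining a duplicate.
    load_model.__doc__ = f"""{Backend.load_model.__doc__}"""


print(Wrapper.load_model.__doc__)  # prints the Backend docstring
```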
tests/python/test_with_sklearn.py

@@ -804,10 +804,14 @@ def save_load_model(model_path):
     for train_index, test_index in kf.split(X, y):
         xgb_model = xgb.XGBClassifier(use_label_encoder=False).fit(X[train_index], y[train_index])
         xgb_model.save_model(model_path)
-        xgb_model = xgb.XGBClassifier(use_label_encoder=False)
+
+        xgb_model = xgb.XGBClassifier()
         xgb_model.load_model(model_path)
+
+        assert xgb_model.use_label_encoder is False
         assert isinstance(xgb_model.classes_, np.ndarray)
         assert isinstance(xgb_model._Booster, xgb.Booster)
+
         preds = xgb_model.predict(X[test_index])
         labels = y[test_index]
         err = sum(1 for i in range(len(preds))
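Because both wrapper methods now advertise os.PathLike alongside str, a pathlib.Path round trip works without converting to a string first (assuming, as the core changes above suggest, that the underlying Booster handles PathLike the same way it handles str paths). An illustrative sketch, separate from the test:

```python
import pathlib
import numpy as np
import xgboost as xgb

# Illustrative data; not taken from the test above.
X, y = np.random.rand(32, 5), np.random.randint(2, size=32)
clf = xgb.XGBClassifier(use_label_encoder=False, n_estimators=4).fit(X, y)

path = pathlib.Path("clf.json")  # os.PathLike accepted directly
clf.save_model(path)

clf2 = xgb.XGBClassifier()
clf2.load_model(path)
assert (clf.predict(X) == clf2.predict(X)).all()
```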