Update document for sklearn model IO. (#6809)
* Update the use of JSON. * Remove unnecessary type cast.
This commit is contained in:
parent
905fdd3e08
commit
a5c852660b
@ -1949,7 +1949,7 @@ class Booster(object):
|
|||||||
ctypes.byref(cptr)))
|
ctypes.byref(cptr)))
|
||||||
return ctypes2buffer(cptr, length.value)
|
return ctypes2buffer(cptr, length.value)
|
||||||
|
|
||||||
def load_model(self, fname):
|
def load_model(self, fname: Union[str, bytearray, os.PathLike]) -> None:
|
||||||
"""Load the model from a file or bytearray. Path to file can be local
|
"""Load the model from a file or bytearray. Path to file can be local
|
||||||
or as an URI.
|
or as an URI.
|
||||||
|
|
||||||
@ -1964,11 +1964,11 @@ class Booster(object):
|
|||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
fname : string, os.PathLike, or a memory buffer
|
fname :
|
||||||
Input file name or memory buffer(see also save_raw)
|
Input file name or memory buffer(see also save_raw)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if isinstance(fname, (STRING_TYPES, os.PathLike)):
|
if isinstance(fname, (str, os.PathLike)):
|
||||||
# assume file name, cannot use os.path.exist to check, file can be
|
# assume file name, cannot use os.path.exist to check, file can be
|
||||||
# from URL.
|
# from URL.
|
||||||
fname = os.fspath(os.path.expanduser(fname))
|
fname = os.fspath(os.path.expanduser(fname))
|
||||||
|
|||||||
@ -4,6 +4,7 @@
|
|||||||
import copy
|
import copy
|
||||||
import warnings
|
import warnings
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
from typing import Union, Optional, List, Dict, Callable, Tuple, Any
|
from typing import Union, Optional, List, Dict, Callable, Tuple, Any
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from .core import Booster, DMatrix, XGBoostError
|
from .core import Booster, DMatrix, XGBoostError
|
||||||
@ -500,25 +501,7 @@ class XGBModel(XGBModelBase):
|
|||||||
)
|
)
|
||||||
return self._estimator_type # pylint: disable=no-member
|
return self._estimator_type # pylint: disable=no-member
|
||||||
|
|
||||||
def save_model(self, fname: str):
|
def save_model(self, fname: Union[str, os.PathLike]):
|
||||||
"""Save the model to a file.
|
|
||||||
|
|
||||||
The model is saved in an XGBoost internal format which is universal
|
|
||||||
among the various XGBoost interfaces. Auxiliary attributes of the
|
|
||||||
Python Booster object (such as feature names) will not be saved.
|
|
||||||
|
|
||||||
.. note::
|
|
||||||
|
|
||||||
See:
|
|
||||||
|
|
||||||
https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
fname : string
|
|
||||||
Output file name
|
|
||||||
|
|
||||||
"""
|
|
||||||
meta = dict()
|
meta = dict()
|
||||||
for k, v in self.__dict__.items():
|
for k, v in self.__dict__.items():
|
||||||
if k == '_le':
|
if k == '_le':
|
||||||
@ -542,27 +525,18 @@ class XGBModel(XGBModelBase):
|
|||||||
# Delete the attribute after save
|
# Delete the attribute after save
|
||||||
self.get_booster().set_attr(scikit_learn=None)
|
self.get_booster().set_attr(scikit_learn=None)
|
||||||
|
|
||||||
def load_model(self, fname):
|
save_model.__doc__ = f"""{Booster.save_model.__doc__}"""
|
||||||
|
|
||||||
|
def load_model(self, fname: Union[str, bytearray, os.PathLike]) -> None:
|
||||||
# pylint: disable=attribute-defined-outside-init
|
# pylint: disable=attribute-defined-outside-init
|
||||||
"""Load the model from a file.
|
|
||||||
|
|
||||||
The model is loaded from an XGBoost internal format which is universal
|
|
||||||
among the various XGBoost interfaces. Auxiliary attributes of the
|
|
||||||
Python Booster object (such as feature names) will not be loaded.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
fname : string
|
|
||||||
Input file name.
|
|
||||||
|
|
||||||
"""
|
|
||||||
if not hasattr(self, '_Booster'):
|
if not hasattr(self, '_Booster'):
|
||||||
self._Booster = Booster({'n_jobs': self.n_jobs})
|
self._Booster = Booster({'n_jobs': self.n_jobs})
|
||||||
self._Booster.load_model(fname)
|
self.get_booster().load_model(fname)
|
||||||
meta = self._Booster.attr('scikit_learn')
|
meta = self.get_booster().attr('scikit_learn')
|
||||||
if meta is None:
|
if meta is None:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
'Loading a native XGBoost model with Scikit-Learn interface.')
|
'Loading a native XGBoost model with Scikit-Learn interface.'
|
||||||
|
)
|
||||||
return
|
return
|
||||||
meta = json.loads(meta)
|
meta = json.loads(meta)
|
||||||
states = dict()
|
states = dict()
|
||||||
@ -574,9 +548,6 @@ class XGBModel(XGBModelBase):
|
|||||||
if k == 'classes_':
|
if k == 'classes_':
|
||||||
self.classes_ = np.array(v)
|
self.classes_ = np.array(v)
|
||||||
continue
|
continue
|
||||||
if k == 'use_label_encoder':
|
|
||||||
self.use_label_encoder = bool(v)
|
|
||||||
continue
|
|
||||||
if k == "_estimator_type":
|
if k == "_estimator_type":
|
||||||
if self._get_type() != v:
|
if self._get_type() != v:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
@ -589,6 +560,8 @@ class XGBModel(XGBModelBase):
|
|||||||
# Delete the attribute after load
|
# Delete the attribute after load
|
||||||
self.get_booster().set_attr(scikit_learn=None)
|
self.get_booster().set_attr(scikit_learn=None)
|
||||||
|
|
||||||
|
load_model.__doc__ = f"""{Booster.load_model.__doc__}"""
|
||||||
|
|
||||||
def _configure_fit(
|
def _configure_fit(
|
||||||
self,
|
self,
|
||||||
booster: Optional[Union[Booster, "XGBModel"]],
|
booster: Optional[Union[Booster, "XGBModel"]],
|
||||||
|
|||||||
@ -804,10 +804,14 @@ def save_load_model(model_path):
|
|||||||
for train_index, test_index in kf.split(X, y):
|
for train_index, test_index in kf.split(X, y):
|
||||||
xgb_model = xgb.XGBClassifier(use_label_encoder=False).fit(X[train_index], y[train_index])
|
xgb_model = xgb.XGBClassifier(use_label_encoder=False).fit(X[train_index], y[train_index])
|
||||||
xgb_model.save_model(model_path)
|
xgb_model.save_model(model_path)
|
||||||
xgb_model = xgb.XGBClassifier(use_label_encoder=False)
|
|
||||||
|
xgb_model = xgb.XGBClassifier()
|
||||||
xgb_model.load_model(model_path)
|
xgb_model.load_model(model_path)
|
||||||
|
|
||||||
|
assert xgb_model.use_label_encoder is False
|
||||||
assert isinstance(xgb_model.classes_, np.ndarray)
|
assert isinstance(xgb_model.classes_, np.ndarray)
|
||||||
assert isinstance(xgb_model._Booster, xgb.Booster)
|
assert isinstance(xgb_model._Booster, xgb.Booster)
|
||||||
|
|
||||||
preds = xgb_model.predict(X[test_index])
|
preds = xgb_model.predict(X[test_index])
|
||||||
labels = y[test_index]
|
labels = y[test_index]
|
||||||
err = sum(1 for i in range(len(preds))
|
err = sum(1 for i in range(len(preds))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user