Save Scikit-Learn attributes into learner attributes. (#5245)

* Remove the recommendation for pickle.
* Save skl attributes in booster.attr.
* Test loading scikit-learn model with native booster.
parent c67163250e
commit 472ded549d
@@ -85,6 +85,14 @@ again after the model is loaded. If the customized function is useful, please co
 making a PR for implementing it inside XGBoost, this way we can have your functions
 working with different language bindings.
 
+******************************************************
+Loading pickled file from different version of XGBoost
+******************************************************
+
+As noted, a pickled model is neither portable nor stable, but in some cases pickled
+models are valuable. One way to restore them in the future is to load them back with
+that specific version of Python and XGBoost, then export the model by calling `save_model`.
+
 ********************************************************
 Saving and Loading the internal parameters configuration
 ********************************************************
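A minimal sketch of the recovery path this documentation hunk describes, assuming a hypothetical pickle file `old-model.pkl` produced by an older XGBoost; the snippet must run under that old Python/XGBoost environment, and both file names are made up:

# Run this inside the *old* environment that produced the pickle.
import pickle

with open('old-model.pkl', 'rb') as fd:
    booster = pickle.load(fd)       # only works with the matching XGBoost version
booster.save_model('model.bin')     # stable format, loadable by newer XGBoost
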
@@ -4,9 +4,10 @@
 import abc
 import os
 import sys
 
 from pathlib import PurePath
-
+import numpy as np
+
 assert (sys.version_info[0] == 3), 'Python 2 is no longer supported.'
 
 # pylint: disable=invalid-name, redefined-builtin
@@ -148,7 +149,29 @@ try:
 
     XGBKFold = KFold
     XGBStratifiedKFold = StratifiedKFold
-    XGBLabelEncoder = LabelEncoder
+
+    class XGBoostLabelEncoder(LabelEncoder):
+        '''Label encoder with JSON serialization methods.'''
+        def to_json(self):
+            '''Returns a JSON compatible dictionary'''
+            meta = dict()
+            for k, v in self.__dict__.items():
+                if isinstance(v, np.ndarray):
+                    meta[k] = v.tolist()
+                else:
+                    meta[k] = v
+            return meta
+
+        def from_json(self, doc):
+            # pylint: disable=attribute-defined-outside-init
+            '''Load the encoder back from a JSON compatible dict.'''
+            meta = dict()
+            for k, v in doc.items():
+                if k == 'classes_':
+                    self.classes_ = np.array(v)
+                    continue
+                meta[k] = v
+            self.__dict__.update(meta)
 except ImportError:
     SKLEARN_INSTALLED = False
 
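A quick sketch of how the new encoder round-trips through JSON; the import path follows the test file later in this commit, and the labels are made up:

import json
import numpy as np
from xgboost.sklearn import XGBoostLabelEncoder

le = XGBoostLabelEncoder().fit(np.array(['cat', 'dog', 'cat']))
doc = json.dumps(le.to_json())        # the classes_ ndarray becomes a plain list

restored = XGBoostLabelEncoder()
restored.from_json(json.loads(doc))   # classes_ comes back as an np.ndarray
assert (restored.classes_ == le.classes_).all()
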
@@ -159,7 +182,7 @@ except ImportError:
 
     XGBKFold = None
     XGBStratifiedKFold = None
-    XGBLabelEncoder = None
+    XGBoostLabelEncoder = None
 
 
 # dask
@@ -11,7 +11,7 @@ from .training import train
 # Do not use class names on scikit-learn directly. Re-define the classes on
 # .compat to guarantee the behavior without scikit-learn
 from .compat import (SKLEARN_INSTALLED, XGBModelBase,
-                     XGBClassifierBase, XGBRegressorBase, XGBLabelEncoder)
+                     XGBClassifierBase, XGBRegressorBase, XGBoostLabelEncoder)
 
 
 def _objective_decorator(func):
@@ -330,54 +330,96 @@ class XGBModel(XGBModelBase):
         """Gets the number of xgboost boosting rounds."""
         return self.n_estimators
 
-    def save_model(self, fname):
-        """
-        Save the model to a file.
+    def save_model(self, fname: str):
+        """Save the model to a file.
 
-        The model is saved in an XGBoost internal binary format which is
-        universal among the various XGBoost interfaces. Auxiliary attributes of
-        the Python Booster object (such as feature names) will not be loaded.
-        Label encodings (text labels to numeric labels) will be also lost.
-        **If you are using only the Python interface, we recommend pickling the
-        model object for best results.**
+        The model is saved in an XGBoost internal format which is universal
+        among the various XGBoost interfaces. Auxiliary attributes of the
+        Python Booster object (such as feature names) will not be saved.
+
+        .. note::
+
+          See:
+
+          https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html
 
         Parameters
         ----------
         fname : string
             Output file name
 
         """
-        warnings.warn("save_model: Useful attributes in the Python " +
-                      "object {} will be lost. ".format(type(self).__name__) +
-                      "If you did not mean to export the model to " +
-                      "a non-Python binding of XGBoost, consider " +
-                      "using `pickle` or `joblib` to save your model.",
-                      Warning)
+        meta = dict()
+        for k, v in self.__dict__.items():
+            if k == '_le':
+                meta['_le'] = self._le.to_json()
+                continue
+            if k == '_Booster':
+                continue
+            if k == 'classes_':
+                # numpy array is not JSON serializable
+                meta['classes_'] = self.classes_.tolist()
+                continue
+            try:
+                json.dumps({k: v})
+                meta[k] = v
+            except TypeError:
+                warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.')
+        meta['type'] = type(self).__name__
+        meta = json.dumps(meta)
+        self.get_booster().set_attr(scikit_learn=meta)
         self.get_booster().save_model(fname)
+        # Delete the attribute after save
+        self.get_booster().set_attr(scikit_learn=None)
 
     def load_model(self, fname):
-        """
-        Load the model from a file.
+        # pylint: disable=attribute-defined-outside-init
+        """Load the model from a file.
 
-        The model is loaded from an XGBoost internal binary format which is
-        universal among the various XGBoost interfaces. Auxiliary attributes of
-        the Python Booster object (such as feature names) will not be loaded.
-        Label encodings (text labels to numeric labels) will be also lost.
-        **If you are using only the Python interface, we recommend pickling the
-        model object for best results.**
+        The model is loaded from an XGBoost internal format which is universal
+        among the various XGBoost interfaces. Auxiliary attributes of the
+        Python Booster object (such as feature names) will not be loaded.
 
         Parameters
         ----------
-        fname : string or a memory buffer
-            Input file name or memory buffer(see also save_raw)
+        fname : string
+            Input file name.
 
         """
         if self._Booster is None:
             self._Booster = Booster({'n_jobs': self.n_jobs})
         self._Booster.load_model(fname)
+        meta = self._Booster.attr('scikit_learn')
+        if meta is None:
+            warnings.warn(
+                'Loading a native XGBoost model with Scikit-Learn interface.')
+            return
+        meta = json.loads(meta)
+        states = dict()
+        for k, v in meta.items():
+            if k == '_le':
+                self._le = XGBoostLabelEncoder()
+                self._le.from_json(v)
+                continue
+            if k == 'classes_':
+                self.classes_ = np.array(v)
+                continue
+            if k == 'type' and type(self).__name__ != v:
+                msg = f'Current model type: {type(self).__name__}, ' + \
+                    f'type of model in file: {v}'
+                raise TypeError(msg)
+            if k == 'type':
+                continue
+            states[k] = v
+        self.__dict__.update(states)
+        # Delete the attribute after load
+        self.get_booster().set_attr(scikit_learn=None)
 
     def fit(self, X, y, sample_weight=None, base_margin=None,
             eval_set=None, eval_metric=None, early_stopping_rounds=None,
-            verbose=True, xgb_model=None, sample_weight_eval_set=None, callbacks=None):
-        # pylint: disable=missing-docstring,invalid-name,attribute-defined-outside-init
+            verbose=True, xgb_model=None, sample_weight_eval_set=None,
+            callbacks=None):
+        # pylint: disable=invalid-name,attribute-defined-outside-init
         """Fit gradient boosting model
 
         Parameters
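A hedged sketch of the round trip these two methods now implement; the file name is hypothetical, `load_digits` merely supplies a small binary problem, and the final assertion follows from the save/load code above (JSON-serializable attributes ride along in the `scikit_learn` booster attribute and are restored into `__dict__`):

import xgboost as xgb
from sklearn.datasets import load_digits

X, y = load_digits(n_class=2, return_X_y=True)
clf = xgb.XGBClassifier(n_estimators=4).fit(X, y)
clf.save_model('clf.json')       # scikit-learn meta travels as a booster attribute

clf2 = xgb.XGBClassifier()
clf2.load_model('clf.json')      # meta restored, then the attribute is cleared
assert clf2.get_booster().attr('scikit_learn') is None
assert clf2.n_estimators == 4    # plain JSON-serializable attributes survive
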
@@ -678,7 +720,7 @@ class XGBModel(XGBModelBase):
     "Implementation of the scikit-learn API for XGBoost classification.",
     ['model', 'objective'])
 class XGBClassifier(XGBModel, XGBClassifierBase):
-    # pylint: disable=missing-docstring,too-many-arguments,invalid-name,too-many-instance-attributes
+    # pylint: disable=missing-docstring,invalid-name,too-many-instance-attributes
     def __init__(self, objective="binary:logistic", **kwargs):
         super().__init__(objective=objective, **kwargs)
 
@@ -714,7 +756,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
             else:
                 xgb_options.update({"eval_metric": eval_metric})
 
-        self._le = XGBLabelEncoder().fit(y)
+        self._le = XGBoostLabelEncoder().fit(y)
         training_labels = self._le.transform(y)
 
         if eval_set is not None:
@@ -809,7 +851,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
                                missing=self.missing, nthread=self.n_jobs)
         if ntree_limit is None:
             ntree_limit = getattr(self, "best_ntree_limit", 0)
-        class_probs = self.get_booster().predict(test_dmatrix,
-                                                 output_margin=output_margin,
-                                                 ntree_limit=ntree_limit,
-                                                 validate_features=validate_features)
+        class_probs = self.get_booster().predict(
+            test_dmatrix,
+            output_margin=output_margin,
+            ntree_limit=ntree_limit,
+            validate_features=validate_features)
@@ -822,7 +865,12 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         else:
             column_indexes = np.repeat(0, class_probs.shape[0])
             column_indexes[class_probs > 0.5] = 1
-        return self._le.inverse_transform(column_indexes)
+
+        if hasattr(self, '_le'):
+            return self._le.inverse_transform(column_indexes)
+        warnings.warn(
+            'Label encoder is not defined. Returning class probability.')
+        return class_probs
 
     def predict_proba(self, data, ntree_limit=None, validate_features=True,
                       base_margin=None):
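This fallback is what makes a native model usable through the wrapper: with no stored label encoder, `predict` returns the raw class probabilities instead of decoded labels. A sketch under that assumption (the file name is hypothetical):

import xgboost as xgb
from sklearn.datasets import load_digits

X, y = load_digits(n_class=2, return_X_y=True)
booster = xgb.train({'objective': 'binary:logistic'},
                    xgb.DMatrix(X, y), num_boost_round=4)
booster.save_model('native.bin')   # native model: no scikit-learn meta attached

clf = xgb.XGBClassifier()
clf.load_model('native.bin')       # warns about loading a native model
probs = clf.predict(X)             # no _le, so probabilities come back
assert ((probs >= 0) & (probs <= 1)).all()
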
@@ -52,7 +52,7 @@ class PeekableInStream : public dmlc::Stream {
 class FixedSizeStream : public PeekableInStream {
  public:
   explicit FixedSizeStream(PeekableInStream* stream);
-  ~FixedSizeStream() = default;
+  ~FixedSizeStream() override = default;
 
   size_t Read(void* dptr, size_t size) override;
   size_t PeekRead(void* dptr, size_t size) override;
@@ -1,6 +1,7 @@
 /*!
  * Copyright (c) by Contributors 2019
  */
+#include <cctype>
 #include <sstream>
 #include <limits>
 #include <cmath>
@@ -351,7 +352,9 @@ Json JsonReader::Parse() {
     return ParseObject();
   } else if ( c == '[' ) {
     return ParseArray();
-  } else if ( c == '-' || std::isdigit(c) ) {
+  } else if ( c == '-' || std::isdigit(c) ||
+              c == 'N' ) {
+    // For now we only accept `NaN`, not `nan`, as the latter violates LR(1) with `null`.
     return ParseNumber();
   } else if ( c == '\"' ) {
     return ParseString();
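On the LR(1) remark in the new comment: with a single character of lookahead, a leading 'n' is ambiguous between `nan` and `null`, while 'N' can only begin `NaN`. An illustrative Python sketch of the same first-character dispatch (names are ours, not the C++ API):

def dispatch(c):
    # First-character dispatch mirroring JsonReader::Parse above.
    if c == '{':
        return 'object'
    if c == '[':
        return 'array'
    if c == '"':
        return 'string'
    if c == '-' or c.isdigit() or c == 'N':   # 'N' can only start NaN
        return 'number'
    if c == 'n':                              # 'n' must start null; accepting
        return 'null'                         # `nan` would clash right here
    raise ValueError(c)
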
@@ -547,6 +550,13 @@ Json JsonReader::ParseNumber() {
 
   // TODO(trivialfis): Add back all the checks for number
   bool negative = false;
+  if (XGBOOST_EXPECT(*p == 'N', false)) {
+    GetChar('N');
+    GetChar('a');
+    GetChar('N');
+    return Json(static_cast<Number::Float>(std::numeric_limits<float>::quiet_NaN()));
+  }
+
   if ('-' == *p) {
     ++p;
     negative = true;
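Accepting `NaN` matters here because the scikit-learn meta introduced earlier in this commit is written with Python's `json.dumps`, which emits a bare NaN token for float('nan'), for example the `missing` attribute of `XGBModel`, which defaults to `np.nan`:

import json

# Python's default encoder emits the non-standard bare NaN token, which the
# reader above must now be able to parse back.
print(json.dumps({'missing': float('nan')}))   # -> {"missing": NaN}
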
@@ -661,9 +661,9 @@ class LearnerImpl : public Learner {
     CHECK(header == serialisation_header_)  // NOLINT
         << R"doc(
 
-  If you are loading a serialized model (like pickle in Python) generated by older XGBoost,
-  please export the model by calling `Booster.save_model` from that version first, then load
-  it back in current version. See:
+  If you are loading a serialized model (like pickle in Python) generated by older
+  XGBoost, please export the model by calling `Booster.save_model` from that version
+  first, then load it back in current version. See:
 
     https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html
 
@@ -30,7 +30,8 @@ def json_model(model_path, parameters):
 class TestModels(unittest.TestCase):
     def test_glm(self):
         param = {'verbosity': 0, 'objective': 'binary:logistic',
-                 'booster': 'gblinear', 'alpha': 0.0001, 'lambda': 1, 'nthread': 1}
+                 'booster': 'gblinear', 'alpha': 0.0001, 'lambda': 1,
+                 'nthread': 1}
         watchlist = [(dtest, 'eval'), (dtrain, 'train')]
         num_round = 4
         bst = xgb.train(param, dtrain, num_round, watchlist)
@@ -1,5 +1,6 @@
 import numpy as np
 import xgboost as xgb
+from xgboost.sklearn import XGBoostLabelEncoder
 import testing as tm
 import tempfile
 import os
@@ -614,7 +615,7 @@ def test_validation_weights_xgbclassifier():
                         for i in [0, 1]))
 
 
-def test_save_load_model():
+def save_load_model(model_path):
     from sklearn.datasets import load_digits
     from sklearn.model_selection import KFold
 
@@ -622,18 +623,64 @@
     y = digits['target']
     X = digits['data']
     kf = KFold(n_splits=2, shuffle=True, random_state=rng)
-    with TemporaryDirectory() as tempdir:
-        model_path = os.path.join(tempdir, 'digits.model')
-        for train_index, test_index in kf.split(X, y):
-            xgb_model = xgb.XGBClassifier().fit(X[train_index], y[train_index])
-            xgb_model.save_model(model_path)
-            xgb_model = xgb.XGBModel()
-            xgb_model.load_model(model_path)
-            preds = xgb_model.predict(X[test_index])
-            labels = y[test_index]
-            err = sum(1 for i in range(len(preds))
-                      if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
-            assert err < 0.1
+    for train_index, test_index in kf.split(X, y):
+        xgb_model = xgb.XGBClassifier().fit(X[train_index], y[train_index])
+        xgb_model.save_model(model_path)
+        xgb_model = xgb.XGBClassifier()
+        xgb_model.load_model(model_path)
+        assert isinstance(xgb_model.classes_, np.ndarray)
+        assert isinstance(xgb_model._Booster, xgb.Booster)
+        assert isinstance(xgb_model._le, XGBoostLabelEncoder)
+        assert isinstance(xgb_model._le.classes_, np.ndarray)
+        preds = xgb_model.predict(X[test_index])
+        labels = y[test_index]
+        err = sum(1 for i in range(len(preds))
+                  if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
+        assert err < 0.1
+        assert xgb_model.get_booster().attr('scikit_learn') is None
+
+        # test native booster
+        preds = xgb_model.predict(X[test_index], output_margin=True)
+        booster = xgb.Booster(model_file=model_path)
+        predt_1 = booster.predict(xgb.DMatrix(X[test_index]),
+                                  output_margin=True)
+        assert np.allclose(preds, predt_1)
+
+        with pytest.raises(TypeError):
+            xgb_model = xgb.XGBModel()
+            xgb_model.load_model(model_path)
+
+
+def test_save_load_model():
+    with TemporaryDirectory() as tempdir:
+        model_path = os.path.join(tempdir, 'digits.model')
+        save_load_model(model_path)
+
+    with TemporaryDirectory() as tempdir:
+        model_path = os.path.join(tempdir, 'digits.model.json')
+        save_load_model(model_path)
+
+    from sklearn.datasets import load_digits
+    with TemporaryDirectory() as tempdir:
+        model_path = os.path.join(tempdir, 'digits.model.json')
+        digits = load_digits(2)
+        y = digits['target']
+        X = digits['data']
+        booster = xgb.train({'tree_method': 'hist',
+                             'objective': 'binary:logistic'},
+                            dtrain=xgb.DMatrix(X, y),
+                            num_boost_round=4)
+        predt_0 = booster.predict(xgb.DMatrix(X))
+        booster.save_model(model_path)
+        cls = xgb.XGBClassifier()
+        cls.load_model(model_path)
+        predt_1 = cls.predict(X)
+        assert np.allclose(predt_0, predt_1)
+
+        cls = xgb.XGBModel()
+        cls.load_model(model_path)
+        predt_1 = cls.predict(X)
+        assert np.allclose(predt_0, predt_1)
 
 
 def test_RFECV():