Save Scikit-Learn attributes into learner attributes. (#5245)

* Remove the recommendation for pickle.

* Save skl attributes in booster.attr

* Test loading scikit-learn model with native booster.
Jiaming Yuan, 2020-01-30 16:00:18 +08:00, committed by GitHub
parent c67163250e
commit 472ded549d
8 changed files with 194 additions and 57 deletions

doc/tutorials/saving_model.rst

@@ -85,6 +85,14 @@ again after the model is loaded. If the customized function is useful, please consider
 making a PR for implementing it inside XGBoost, this way we can have your functions
 working with different language bindings.
 
+******************************************************
+Loading pickled file from different version of XGBoost
+******************************************************
+
+As noted, a pickled model is neither portable nor stable, but in some cases pickled
+models are valuable.  One way to restore such a model in the future is to load it back
+with the specific version of Python and XGBoost that produced it, then export the model
+by calling `save_model`.
+
 ********************************************************
 Saving and Loading the internal parameters configuration
 ********************************************************
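The recovery path this new doc section describes can be sketched in a few lines. A minimal sketch, assuming `model.pkl` was written by an older XGBoost and that the matching Python/XGBoost versions are installed in the current environment:

    import pickle

    # Load the pickle with the *same* Python and XGBoost versions that wrote it.
    with open('model.pkl', 'rb') as fd:
        old_model = pickle.load(fd)

    # Export to the stable internal format; newer XGBoost releases can load this
    # file with Booster(model_file='exported.model').
    old_model.save_model('exported.model')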

python-package/xgboost/compat.py

@@ -4,9 +4,10 @@
 import abc
 import os
 import sys
 from pathlib import PurePath
+import numpy as np
 
 assert (sys.version_info[0] == 3), 'Python 2 is no longer supported.'
 
 # pylint: disable=invalid-name, redefined-builtin
@@ -148,7 +149,29 @@ try:
     XGBKFold = KFold
     XGBStratifiedKFold = StratifiedKFold
-    XGBLabelEncoder = LabelEncoder
+
+    class XGBoostLabelEncoder(LabelEncoder):
+        '''Label encoder with JSON serialization methods.'''
+        def to_json(self):
+            '''Returns a JSON compatible dictionary'''
+            meta = dict()
+            for k, v in self.__dict__.items():
+                if isinstance(v, np.ndarray):
+                    meta[k] = v.tolist()
+                else:
+                    meta[k] = v
+            return meta
+
+        def from_json(self, doc):
+            # pylint: disable=attribute-defined-outside-init
+            '''Load the encoder back from a JSON compatible dict.'''
+            meta = dict()
+            for k, v in doc.items():
+                if k == 'classes_':
+                    self.classes_ = np.array(v)
+                    continue
+                meta[k] = v
+            self.__dict__.update(meta)
+
 except ImportError:
     SKLEARN_INSTALLED = False
@@ -159,7 +182,7 @@ except ImportError:
     XGBKFold = None
     XGBStratifiedKFold = None
-    XGBLabelEncoder = None
+    XGBoostLabelEncoder = None
 
 # dask
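For reference, a minimal round trip through the new encoder's JSON hooks (the toy labels are ours; requires scikit-learn and numpy):

    import json
    import numpy as np
    from xgboost.compat import XGBoostLabelEncoder

    le = XGBoostLabelEncoder().fit(['cat', 'dog', 'cat'])
    doc = json.dumps(le.to_json())        # the classes_ ndarray becomes a plain list

    restored = XGBoostLabelEncoder()
    restored.from_json(json.loads(doc))   # classes_ is rebuilt as an ndarray
    assert isinstance(restored.classes_, np.ndarray)
    assert list(restored.classes_) == ['cat', 'dog']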

python-package/xgboost/sklearn.py

@@ -11,7 +11,7 @@ from .training import train
 # Do not use class names on scikit-learn directly. Re-define the classes on
 # .compat to guarantee the behavior without scikit-learn
 from .compat import (SKLEARN_INSTALLED, XGBModelBase,
-                     XGBClassifierBase, XGBRegressorBase, XGBLabelEncoder)
+                     XGBClassifierBase, XGBRegressorBase, XGBoostLabelEncoder)
 
 
 def _objective_decorator(func):
@@ -330,54 +330,96 @@ class XGBModel(XGBModelBase):
         """Gets the number of xgboost boosting rounds."""
         return self.n_estimators
 
-    def save_model(self, fname):
-        """
-        Save the model to a file.
+    def save_model(self, fname: str):
+        """Save the model to a file.
 
-        The model is saved in an XGBoost internal binary format which is
-        universal among the various XGBoost interfaces. Auxiliary attributes of
-        the Python Booster object (such as feature names) will not be loaded.
-        Label encodings (text labels to numeric labels) will be also lost.
-        **If you are using only the Python interface, we recommend pickling the
-        model object for best results.**
+        The model is saved in an XGBoost internal format which is universal
+        among the various XGBoost interfaces. Auxiliary attributes of the
+        Python Booster object (such as feature names) will not be saved.
+
+        .. note::
+
+          See:
+
+          https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html
 
         Parameters
         ----------
         fname : string
             Output file name
         """
-        warnings.warn("save_model: Useful attributes in the Python " +
-                      "object {} will be lost. ".format(type(self).__name__) +
-                      "If you did not mean to export the model to " +
-                      "a non-Python binding of XGBoost, consider " +
-                      "using `pickle` or `joblib` to save your model.",
-                      Warning)
+        meta = dict()
+        for k, v in self.__dict__.items():
+            if k == '_le':
+                meta['_le'] = self._le.to_json()
+                continue
+            if k == '_Booster':
+                continue
+            if k == 'classes_':
+                # numpy array is not JSON serializable
+                meta['classes_'] = self.classes_.tolist()
+                continue
+            try:
+                json.dumps({k: v})
+                meta[k] = v
+            except TypeError:
+                warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.')
+        meta['type'] = type(self).__name__
+        meta = json.dumps(meta)
+        self.get_booster().set_attr(scikit_learn=meta)
         self.get_booster().save_model(fname)
+        # Delete the attribute after save
+        self.get_booster().set_attr(scikit_learn=None)
 
     def load_model(self, fname):
-        """
-        Load the model from a file.
+        # pylint: disable=attribute-defined-outside-init
+        """Load the model from a file.
 
-        The model is loaded from an XGBoost internal binary format which is
-        universal among the various XGBoost interfaces. Auxiliary attributes of
-        the Python Booster object (such as feature names) will not be loaded.
-        Label encodings (text labels to numeric labels) will be also lost.
-        **If you are using only the Python interface, we recommend pickling the
-        model object for best results.**
+        The model is loaded from an XGBoost internal format which is universal
+        among the various XGBoost interfaces. Auxiliary attributes of the
+        Python Booster object (such as feature names) will not be loaded.
 
         Parameters
         ----------
-        fname : string or a memory buffer
-            Input file name or memory buffer(see also save_raw)
+        fname : string
+            Input file name.
         """
         if self._Booster is None:
             self._Booster = Booster({'n_jobs': self.n_jobs})
         self._Booster.load_model(fname)
+        meta = self._Booster.attr('scikit_learn')
+        if meta is None:
+            warnings.warn(
+                'Loading a native XGBoost model with Scikit-Learn interface.')
+            return
+        meta = json.loads(meta)
+        states = dict()
+        for k, v in meta.items():
+            if k == '_le':
+                self._le = XGBoostLabelEncoder()
+                self._le.from_json(v)
+                continue
+            if k == 'classes_':
+                self.classes_ = np.array(v)
+                continue
+            if k == 'type' and type(self).__name__ != v:
+                msg = f'Current model type: {type(self).__name__}, ' + \
+                      f'type of model in file: {v}'
+                raise TypeError(msg)
+            if k == 'type':
+                continue
+            states[k] = v
+        self.__dict__.update(states)
+        # Delete the attribute after load
+        self.get_booster().set_attr(scikit_learn=None)
 
     def fit(self, X, y, sample_weight=None, base_margin=None,
             eval_set=None, eval_metric=None, early_stopping_rounds=None,
-            verbose=True, xgb_model=None, sample_weight_eval_set=None, callbacks=None):
-        # pylint: disable=missing-docstring,invalid-name,attribute-defined-outside-init
+            verbose=True, xgb_model=None, sample_weight_eval_set=None,
+            callbacks=None):
+        # pylint: disable=invalid-name,attribute-defined-outside-init
         """Fit gradient boosting model
 
         Parameters
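The effect of the new save_model body is to pack every JSON-friendly attribute of the estimator into a single Booster attribute. A standalone sketch of just the filtering step, with a hypothetical helper name (json_safe_meta is ours, not part of the commit):

    import json
    import warnings

    def json_safe_meta(obj):
        '''Keep whatever survives json.dumps; warn about and skip the rest.'''
        meta = {}
        for k, v in obj.__dict__.items():
            try:
                json.dumps({k: v})
                meta[k] = v
            except TypeError:
                warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.')
        meta['type'] = type(obj).__name__
        return meta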
@@ -678,7 +720,7 @@ class XGBModel(XGBModelBase):
     "Implementation of the scikit-learn API for XGBoost classification.",
     ['model', 'objective'])
 class XGBClassifier(XGBModel, XGBClassifierBase):
-    # pylint: disable=missing-docstring,too-many-arguments,invalid-name,too-many-instance-attributes
+    # pylint: disable=missing-docstring,invalid-name,too-many-instance-attributes
     def __init__(self, objective="binary:logistic", **kwargs):
         super().__init__(objective=objective, **kwargs)
@@ -714,7 +756,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         else:
             xgb_options.update({"eval_metric": eval_metric})
 
-        self._le = XGBLabelEncoder().fit(y)
+        self._le = XGBoostLabelEncoder().fit(y)
         training_labels = self._le.transform(y)
 
         if eval_set is not None:
@@ -809,7 +851,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
                                missing=self.missing, nthread=self.n_jobs)
         if ntree_limit is None:
             ntree_limit = getattr(self, "best_ntree_limit", 0)
-        class_probs = self.get_booster().predict(test_dmatrix,
-                                                 output_margin=output_margin,
-                                                 ntree_limit=ntree_limit,
-                                                 validate_features=validate_features)
+        class_probs = self.get_booster().predict(
+            test_dmatrix,
+            output_margin=output_margin,
+            ntree_limit=ntree_limit,
+            validate_features=validate_features)
@@ -822,7 +865,12 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         else:
             column_indexes = np.repeat(0, class_probs.shape[0])
             column_indexes[class_probs > 0.5] = 1
-        return self._le.inverse_transform(column_indexes)
+
+        if hasattr(self, '_le'):
+            return self._le.inverse_transform(column_indexes)
+        warnings.warn(
+            'Label encoder is not defined. Returning class probability.')
+        return class_probs
 
     def predict_proba(self, data, ntree_limit=None, validate_features=True,
                       base_margin=None):
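Taken together, these changes allow a pickle-free round trip that keeps the label encoding. A usage sketch under assumed names (the file name, random data, and n_estimators value are illustrative; requires scikit-learn):

    import numpy as np
    import xgboost as xgb

    X = np.random.randn(100, 10)
    y = np.random.choice(['ham', 'spam'], size=100)     # string labels

    clf = xgb.XGBClassifier(n_estimators=4).fit(X, y)
    clf.save_model('clf.model')            # scikit-learn meta rides along in the file

    restored = xgb.XGBClassifier()
    restored.load_model('clf.model')       # no pickle involved
    assert list(restored.classes_) == ['ham', 'spam']   # label encoding survives
    print(restored.predict(X)[:5])         # predictions are string labels again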

src/common/io.h

@@ -52,7 +52,7 @@ class PeekableInStream : public dmlc::Stream {
 class FixedSizeStream : public PeekableInStream {
  public:
   explicit FixedSizeStream(PeekableInStream* stream);
-  ~FixedSizeStream() = default;
+  ~FixedSizeStream() override = default;
 
   size_t Read(void* dptr, size_t size) override;
   size_t PeekRead(void* dptr, size_t size) override;

src/common/json.cc

@@ -1,6 +1,7 @@
 /*!
  * Copyright (c) by Contributors 2019
  */
+#include <cctype>
 #include <sstream>
 #include <limits>
 #include <cmath>
@@ -351,7 +352,9 @@ Json JsonReader::Parse() {
     return ParseObject();
   } else if ( c == '[' ) {
     return ParseArray();
-  } else if ( c == '-' || std::isdigit(c) ) {
+  } else if ( c == '-' || std::isdigit(c) ||
+              c == 'N' ) {
+    // For now we only accept `NaN`, not `nan`, as the latter violates LR(1) with `null`.
     return ParseNumber();
   } else if ( c == '\"' ) {
     return ParseString();
@@ -547,6 +550,13 @@ Json JsonReader::ParseNumber() {
   // TODO(trivialfis): Add back all the checks for number
   bool negative = false;
+  if (XGBOOST_EXPECT(*p == 'N', false)) {
+    GetChar('N');
+    GetChar('a');
+    GetChar('N');
+    return Json(static_cast<Number::Float>(std::numeric_limits<float>::quiet_NaN()));
+  }
   if ('-' == *p) {
     ++p;
     negative = true;
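The NaN support is presumably needed because Python's json module emits the non-standard `NaN` token by default, and wrapper attributes such as `missing` can legitimately hold NaN, so the serialized scikit-learn meta may contain it. A plain-Python illustration:

    import json

    doc = json.dumps({'missing': float('nan')})  # json emits the bare NaN token
    assert doc == '{"missing": NaN}'
    value = json.loads(doc)['missing']
    assert value != value                        # NaN is not equal to itself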

src/learner.cc

@@ -661,9 +661,9 @@ class LearnerImpl : public Learner {
     CHECK(header == serialisation_header_)  // NOLINT
         << R"doc(
-  If you are loading a serialized model (like pickle in Python) generated by older XGBoost,
-  please export the model by calling `Booster.save_model` from that version first, then load
-  it back in current version. See:
+  If you are loading a serialized model (like pickle in Python) generated by older
+  XGBoost, please export the model by calling `Booster.save_model` from that version
+  first, then load it back in the current version. See:
 
     https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html

tests/python/test_basic_models.py

@@ -30,7 +30,8 @@ def json_model(model_path, parameters):
 class TestModels(unittest.TestCase):
     def test_glm(self):
         param = {'verbosity': 0, 'objective': 'binary:logistic',
-                 'booster': 'gblinear', 'alpha': 0.0001, 'lambda': 1, 'nthread': 1}
+                 'booster': 'gblinear', 'alpha': 0.0001, 'lambda': 1,
+                 'nthread': 1}
         watchlist = [(dtest, 'eval'), (dtrain, 'train')]
         num_round = 4
         bst = xgb.train(param, dtrain, num_round, watchlist)

tests/python/test_with_sklearn.py

@@ -1,5 +1,6 @@
 import numpy as np
 import xgboost as xgb
+from xgboost.sklearn import XGBoostLabelEncoder
 import testing as tm
 import tempfile
 import os
@@ -614,7 +615,7 @@ def test_validation_weights_xgbclassifier():
                    for i in [0, 1]))
 
-def test_save_load_model():
+def save_load_model(model_path):
     from sklearn.datasets import load_digits
     from sklearn.model_selection import KFold
@@ -622,18 +623,64 @@ def test_save_load_model():
     y = digits['target']
     X = digits['data']
     kf = KFold(n_splits=2, shuffle=True, random_state=rng)
-    with TemporaryDirectory() as tempdir:
-        model_path = os.path.join(tempdir, 'digits.model')
-        for train_index, test_index in kf.split(X, y):
-            xgb_model = xgb.XGBClassifier().fit(X[train_index], y[train_index])
-            xgb_model.save_model(model_path)
-            xgb_model = xgb.XGBModel()
-            xgb_model.load_model(model_path)
-            preds = xgb_model.predict(X[test_index])
-            labels = y[test_index]
-            err = sum(1 for i in range(len(preds))
-                      if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
-            assert err < 0.1
+    for train_index, test_index in kf.split(X, y):
+        xgb_model = xgb.XGBClassifier().fit(X[train_index], y[train_index])
+        xgb_model.save_model(model_path)
+
+        xgb_model = xgb.XGBClassifier()
+        xgb_model.load_model(model_path)
+        assert isinstance(xgb_model.classes_, np.ndarray)
+        assert isinstance(xgb_model._Booster, xgb.Booster)
+        assert isinstance(xgb_model._le, XGBoostLabelEncoder)
+        assert isinstance(xgb_model._le.classes_, np.ndarray)
+
+        preds = xgb_model.predict(X[test_index])
+        labels = y[test_index]
+        err = sum(1 for i in range(len(preds))
+                  if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
+        assert err < 0.1
+        assert xgb_model.get_booster().attr('scikit_learn') is None
+
+        # test native booster
+        preds = xgb_model.predict(X[test_index], output_margin=True)
+        booster = xgb.Booster(model_file=model_path)
+        predt_1 = booster.predict(xgb.DMatrix(X[test_index]),
+                                  output_margin=True)
+        assert np.allclose(preds, predt_1)
+
+        with pytest.raises(TypeError):
+            xgb_model = xgb.XGBModel()
+            xgb_model.load_model(model_path)
+
+
+def test_save_load_model():
+    with TemporaryDirectory() as tempdir:
+        model_path = os.path.join(tempdir, 'digits.model')
+        save_load_model(model_path)
+
+    with TemporaryDirectory() as tempdir:
+        model_path = os.path.join(tempdir, 'digits.model.json')
+        save_load_model(model_path)
+
+    from sklearn.datasets import load_digits
+    with TemporaryDirectory() as tempdir:
+        model_path = os.path.join(tempdir, 'digits.model.json')
+        digits = load_digits(2)
+        y = digits['target']
+        X = digits['data']
+        booster = xgb.train({'tree_method': 'hist',
+                             'objective': 'binary:logistic'},
+                            dtrain=xgb.DMatrix(X, y),
+                            num_boost_round=4)
+        predt_0 = booster.predict(xgb.DMatrix(X))
+        booster.save_model(model_path)
+
+        cls = xgb.XGBClassifier()
+        cls.load_model(model_path)
+        predt_1 = cls.predict(X)
+        assert np.allclose(predt_0, predt_1)
+
+        cls = xgb.XGBModel()
+        cls.load_model(model_path)
+        predt_1 = cls.predict(X)
+        assert np.allclose(predt_0, predt_1)
 
 
 def test_RFECV():