diff --git a/doc/tutorials/saving_model.rst b/doc/tutorials/saving_model.rst index de605f326..aa3b41e6b 100644 --- a/doc/tutorials/saving_model.rst +++ b/doc/tutorials/saving_model.rst @@ -85,6 +85,14 @@ again after the model is loaded. If the customized function is useful, please co making a PR for implementing it inside XGBoost, this way we can have your functions working with different language bindings. +****************************************************** +Loading pickled file from different version of XGBoost +****************************************************** + +As noted, pickled model is neither portable nor stable, but in some cases the pickled +models are valuable. One way to restore it in the future is to load it back with that +specific version of Python and XGBoost, export the model by calling `save_model`. + ******************************************************** Saving and Loading the internal parameters configuration ******************************************************** diff --git a/python-package/xgboost/compat.py b/python-package/xgboost/compat.py index 7ba627c14..bae283de5 100644 --- a/python-package/xgboost/compat.py +++ b/python-package/xgboost/compat.py @@ -4,9 +4,10 @@ import abc import os import sys - from pathlib import PurePath +import numpy as np + assert (sys.version_info[0] == 3), 'Python 2 is no longer supported.' # pylint: disable=invalid-name, redefined-builtin @@ -148,7 +149,29 @@ try: XGBKFold = KFold XGBStratifiedKFold = StratifiedKFold - XGBLabelEncoder = LabelEncoder + + class XGBoostLabelEncoder(LabelEncoder): + '''Label encoder with JSON serialization methods.''' + def to_json(self): + '''Returns a JSON compatible dictionary''' + meta = dict() + for k, v in self.__dict__.items(): + if isinstance(v, np.ndarray): + meta[k] = v.tolist() + else: + meta[k] = v + return meta + + def from_json(self, doc): + # pylint: disable=attribute-defined-outside-init + '''Load the encoder back from a JSON compatible dict.''' + meta = dict() + for k, v in doc.items(): + if k == 'classes_': + self.classes_ = np.array(v) + continue + meta[k] = v + self.__dict__.update(meta) except ImportError: SKLEARN_INSTALLED = False @@ -159,7 +182,7 @@ except ImportError: XGBKFold = None XGBStratifiedKFold = None - XGBLabelEncoder = None + XGBoostLabelEncoder = None # dask diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 1def0dbf2..d59b93cbc 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -11,7 +11,7 @@ from .training import train # Do not use class names on scikit-learn directly. Re-define the classes on # .compat to guarantee the behavior without scikit-learn from .compat import (SKLEARN_INSTALLED, XGBModelBase, - XGBClassifierBase, XGBRegressorBase, XGBLabelEncoder) + XGBClassifierBase, XGBRegressorBase, XGBoostLabelEncoder) def _objective_decorator(func): @@ -330,54 +330,96 @@ class XGBModel(XGBModelBase): """Gets the number of xgboost boosting rounds.""" return self.n_estimators - def save_model(self, fname): - """ - Save the model to a file. + def save_model(self, fname: str): + """Save the model to a file. - The model is saved in an XGBoost internal binary format which is - universal among the various XGBoost interfaces. Auxiliary attributes of - the Python Booster object (such as feature names) will not be loaded. - Label encodings (text labels to numeric labels) will be also lost. - **If you are using only the Python interface, we recommend pickling the - model object for best results.** + The model is saved in an XGBoost internal format which is universal + among the various XGBoost interfaces. Auxiliary attributes of the + Python Booster object (such as feature names) will not be saved. + + .. note:: + + See: + + https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html Parameters ---------- fname : string Output file name + """ - warnings.warn("save_model: Useful attributes in the Python " + - "object {} will be lost. ".format(type(self).__name__) + - "If you did not mean to export the model to " + - "a non-Python binding of XGBoost, consider " + - "using `pickle` or `joblib` to save your model.", - Warning) + meta = dict() + for k, v in self.__dict__.items(): + if k == '_le': + meta['_le'] = self._le.to_json() + continue + if k == '_Booster': + continue + if k == 'classes_': + # numpy array is not JSON serializable + meta['classes_'] = self.classes_.tolist() + continue + try: + json.dumps({k: v}) + meta[k] = v + except TypeError: + warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.') + meta['type'] = type(self).__name__ + meta = json.dumps(meta) + self.get_booster().set_attr(scikit_learn=meta) self.get_booster().save_model(fname) + # Delete the attribute after save + self.get_booster().set_attr(scikit_learn=None) def load_model(self, fname): - """ - Load the model from a file. + # pylint: disable=attribute-defined-outside-init + """Load the model from a file. - The model is loaded from an XGBoost internal binary format which is - universal among the various XGBoost interfaces. Auxiliary attributes of - the Python Booster object (such as feature names) will not be loaded. - Label encodings (text labels to numeric labels) will be also lost. - **If you are using only the Python interface, we recommend pickling the - model object for best results.** + The model is loaded from an XGBoost internal format which is universal + among the various XGBoost interfaces. Auxiliary attributes of the + Python Booster object (such as feature names) will not be loaded. Parameters ---------- - fname : string or a memory buffer - Input file name or memory buffer(see also save_raw) + fname : string + Input file name. + """ if self._Booster is None: self._Booster = Booster({'n_jobs': self.n_jobs}) self._Booster.load_model(fname) + meta = self._Booster.attr('scikit_learn') + if meta is None: + warnings.warn( + 'Loading a native XGBoost model with Scikit-Learn interface.') + return + meta = json.loads(meta) + states = dict() + for k, v in meta.items(): + if k == '_le': + self._le = XGBoostLabelEncoder() + self._le.from_json(v) + continue + if k == 'classes_': + self.classes_ = np.array(v) + continue + if k == 'type' and type(self).__name__ != v: + msg = f'Current model type: {type(self).__name__}, ' + \ + f'type of model in file: {v}' + raise TypeError(msg) + if k == 'type': + continue + states[k] = v + self.__dict__.update(states) + # Delete the attribute after load + self.get_booster().set_attr(scikit_learn=None) def fit(self, X, y, sample_weight=None, base_margin=None, eval_set=None, eval_metric=None, early_stopping_rounds=None, - verbose=True, xgb_model=None, sample_weight_eval_set=None, callbacks=None): - # pylint: disable=missing-docstring,invalid-name,attribute-defined-outside-init + verbose=True, xgb_model=None, sample_weight_eval_set=None, + callbacks=None): + # pylint: disable=invalid-name,attribute-defined-outside-init """Fit gradient boosting model Parameters @@ -678,7 +720,7 @@ class XGBModel(XGBModelBase): "Implementation of the scikit-learn API for XGBoost classification.", ['model', 'objective']) class XGBClassifier(XGBModel, XGBClassifierBase): - # pylint: disable=missing-docstring,too-many-arguments,invalid-name,too-many-instance-attributes + # pylint: disable=missing-docstring,invalid-name,too-many-instance-attributes def __init__(self, objective="binary:logistic", **kwargs): super().__init__(objective=objective, **kwargs) @@ -714,7 +756,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase): else: xgb_options.update({"eval_metric": eval_metric}) - self._le = XGBLabelEncoder().fit(y) + self._le = XGBoostLabelEncoder().fit(y) training_labels = self._le.transform(y) if eval_set is not None: @@ -809,10 +851,11 @@ class XGBClassifier(XGBModel, XGBClassifierBase): missing=self.missing, nthread=self.n_jobs) if ntree_limit is None: ntree_limit = getattr(self, "best_ntree_limit", 0) - class_probs = self.get_booster().predict(test_dmatrix, - output_margin=output_margin, - ntree_limit=ntree_limit, - validate_features=validate_features) + class_probs = self.get_booster().predict( + test_dmatrix, + output_margin=output_margin, + ntree_limit=ntree_limit, + validate_features=validate_features) if output_margin: # If output_margin is active, simply return the scores return class_probs @@ -822,7 +865,12 @@ class XGBClassifier(XGBModel, XGBClassifierBase): else: column_indexes = np.repeat(0, class_probs.shape[0]) column_indexes[class_probs > 0.5] = 1 - return self._le.inverse_transform(column_indexes) + + if hasattr(self, '_le'): + return self._le.inverse_transform(column_indexes) + warnings.warn( + 'Label encoder is not defined. Returning class probability.') + return class_probs def predict_proba(self, data, ntree_limit=None, validate_features=True, base_margin=None): diff --git a/src/common/io.h b/src/common/io.h index 193239fbd..528296dc7 100644 --- a/src/common/io.h +++ b/src/common/io.h @@ -52,7 +52,7 @@ class PeekableInStream : public dmlc::Stream { class FixedSizeStream : public PeekableInStream { public: explicit FixedSizeStream(PeekableInStream* stream); - ~FixedSizeStream() = default; + ~FixedSizeStream() override = default; size_t Read(void* dptr, size_t size) override; size_t PeekRead(void* dptr, size_t size) override; diff --git a/src/common/json.cc b/src/common/json.cc index e2517c38e..ecdcce3d3 100644 --- a/src/common/json.cc +++ b/src/common/json.cc @@ -1,6 +1,7 @@ /*! * Copyright (c) by Contributors 2019 */ +#include #include #include #include @@ -351,7 +352,9 @@ Json JsonReader::Parse() { return ParseObject(); } else if ( c == '[' ) { return ParseArray(); - } else if ( c == '-' || std::isdigit(c) ) { + } else if ( c == '-' || std::isdigit(c) || + c == 'N' ) { + // For now we only accept `NaN`, not `nan` as the later violiates LR(1) with `null`. return ParseNumber(); } else if ( c == '\"' ) { return ParseString(); @@ -547,6 +550,13 @@ Json JsonReader::ParseNumber() { // TODO(trivialfis): Add back all the checks for number bool negative = false; + if (XGBOOST_EXPECT(*p == 'N', false)) { + GetChar('N'); + GetChar('a'); + GetChar('N'); + return Json(static_cast(std::numeric_limits::quiet_NaN())); + } + if ('-' == *p) { ++p; negative = true; diff --git a/src/learner.cc b/src/learner.cc index 86afaba61..c3cca845e 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -661,13 +661,13 @@ class LearnerImpl : public Learner { CHECK(header == serialisation_header_) // NOLINT << R"doc( -If you are loading a serialized model (like pickle in Python) generated by older XGBoost, -please export the model by calling `Booster.save_model` from that version first, then load -it back in current version. See: + If you are loading a serialized model (like pickle in Python) generated by older + XGBoost, please export the model by calling `Booster.save_model` from that version + first, then load it back in current version. See: - https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html + https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html -for more details about differences between saving model and serializing. + for more details about differences between saving model and serializing. )doc"; int64_t json_offset {-1}; diff --git a/tests/python/test_basic_models.py b/tests/python/test_basic_models.py index 560944d8a..a72071b20 100644 --- a/tests/python/test_basic_models.py +++ b/tests/python/test_basic_models.py @@ -30,7 +30,8 @@ def json_model(model_path, parameters): class TestModels(unittest.TestCase): def test_glm(self): param = {'verbosity': 0, 'objective': 'binary:logistic', - 'booster': 'gblinear', 'alpha': 0.0001, 'lambda': 1, 'nthread': 1} + 'booster': 'gblinear', 'alpha': 0.0001, 'lambda': 1, + 'nthread': 1} watchlist = [(dtest, 'eval'), (dtrain, 'train')] num_round = 4 bst = xgb.train(param, dtrain, num_round, watchlist) diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index df475f78d..58704c0cc 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -1,5 +1,6 @@ import numpy as np import xgboost as xgb +from xgboost.sklearn import XGBoostLabelEncoder import testing as tm import tempfile import os @@ -614,7 +615,7 @@ def test_validation_weights_xgbclassifier(): for i in [0, 1])) -def test_save_load_model(): +def save_load_model(model_path): from sklearn.datasets import load_digits from sklearn.model_selection import KFold @@ -622,18 +623,64 @@ def test_save_load_model(): y = digits['target'] X = digits['data'] kf = KFold(n_splits=2, shuffle=True, random_state=rng) - with TemporaryDirectory() as tempdir: - model_path = os.path.join(tempdir, 'digits.model') - for train_index, test_index in kf.split(X, y): - xgb_model = xgb.XGBClassifier().fit(X[train_index], y[train_index]) - xgb_model.save_model(model_path) + for train_index, test_index in kf.split(X, y): + xgb_model = xgb.XGBClassifier().fit(X[train_index], y[train_index]) + xgb_model.save_model(model_path) + xgb_model = xgb.XGBClassifier() + xgb_model.load_model(model_path) + assert isinstance(xgb_model.classes_, np.ndarray) + assert isinstance(xgb_model._Booster, xgb.Booster) + assert isinstance(xgb_model._le, XGBoostLabelEncoder) + assert isinstance(xgb_model._le.classes_, np.ndarray) + preds = xgb_model.predict(X[test_index]) + labels = y[test_index] + err = sum(1 for i in range(len(preds)) + if int(preds[i] > 0.5) != labels[i]) / float(len(preds)) + assert err < 0.1 + assert xgb_model.get_booster().attr('scikit_learn') is None + + # test native booster + preds = xgb_model.predict(X[test_index], output_margin=True) + booster = xgb.Booster(model_file=model_path) + predt_1 = booster.predict(xgb.DMatrix(X[test_index]), + output_margin=True) + assert np.allclose(preds, predt_1) + + with pytest.raises(TypeError): xgb_model = xgb.XGBModel() xgb_model.load_model(model_path) - preds = xgb_model.predict(X[test_index]) - labels = y[test_index] - err = sum(1 for i in range(len(preds)) - if int(preds[i] > 0.5) != labels[i]) / float(len(preds)) - assert err < 0.1 + + +def test_save_load_model(): + with TemporaryDirectory() as tempdir: + model_path = os.path.join(tempdir, 'digits.model') + save_load_model(model_path) + + with TemporaryDirectory() as tempdir: + model_path = os.path.join(tempdir, 'digits.model.json') + save_load_model(model_path) + + from sklearn.datasets import load_digits + with TemporaryDirectory() as tempdir: + model_path = os.path.join(tempdir, 'digits.model.json') + digits = load_digits(2) + y = digits['target'] + X = digits['data'] + booster = xgb.train({'tree_method': 'hist', + 'objective': 'binary:logistic'}, + dtrain=xgb.DMatrix(X, y), + num_boost_round=4) + predt_0 = booster.predict(xgb.DMatrix(X)) + booster.save_model(model_path) + cls = xgb.XGBClassifier() + cls.load_model(model_path) + predt_1 = cls.predict(X) + assert np.allclose(predt_0, predt_1) + + cls = xgb.XGBModel() + cls.load_model(model_path) + predt_1 = cls.predict(X) + assert np.allclose(predt_0, predt_1) def test_RFECV():