From 594bcea83e7b4a861ff738e38b946ec5821223c6 Mon Sep 17 00:00:00 2001
From: Mike Liu
Date: Sat, 30 Jun 2018 12:21:49 -0700
Subject: [PATCH] Save and load model in sklearn API (#3192)

* Add (load|save)_model to XGBModel

* Add docstring

* Fix docstring

* Fix mixed use of space and tab

* Add a test

* Fix Flake8 style errors
---
 python-package/xgboost/sklearn.py | 24 ++++++++++++++++-
 tests/python/test_with_sklearn.py | 44 +++++++++++++++++++++++++++++++
 2 files changed, 67 insertions(+), 1 deletion(-)

diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py
index eeda07d54..66eff2abd 100644
--- a/python-package/xgboost/sklearn.py
+++ b/python-package/xgboost/sklearn.py
@@ -176,7 +176,7 @@ class XGBModel(XGBModelBase):
         booster : a xgboost booster of underlying model
         """
         if self._Booster is None:
-            raise XGBoostError('need to call fit beforehand')
+            raise XGBoostError('need to call fit or load_model beforehand')
         return self._Booster
 
     def get_params(self, deep=False):
@@ -214,6 +214,28 @@ class XGBModel(XGBModelBase):
         xgb_params.pop('nthread', None)
         return xgb_params
 
+    def save_model(self, fname):
+        """
+        Save the model to a file.
+        Parameters
+        ----------
+        fname : string
+            Output file name
+        """
+        self.get_booster().save_model(fname)
+
+    def load_model(self, fname):
+        """
+        Load the model from a file.
+        Parameters
+        ----------
+        fname : string or a memory buffer
+            Input file name or memory buffer (see also save_raw)
+        """
+        if self._Booster is None:
+            self._Booster = Booster({'nthread': self.n_jobs})
+        self._Booster.load_model(fname)
+
     def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
             early_stopping_rounds=None, verbose=True, xgb_model=None,
             sample_weight_eval_set=None):
diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py
index 6fc2eaecb..b184b3952 100644
--- a/tests/python/test_with_sklearn.py
+++ b/tests/python/test_with_sklearn.py
@@ -1,11 +1,24 @@
 import numpy as np
 import xgboost as xgb
 import testing as tm
+import tempfile
+import os
+import shutil
 
 from nose.tools import raises
 
 rng = np.random.RandomState(1994)
 
 
+class TemporaryDirectory(object):
+    """Context manager for tempfile.mkdtemp()"""
+    def __enter__(self):
+        self.name = tempfile.mkdtemp()
+        return self.name
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        shutil.rmtree(self.name)
+
+
 def test_binary_classification():
     tm._skip_if_no_sklearn()
     from sklearn.datasets import load_digits
@@ -458,3 +471,34 @@ def test_validation_weights_xgbclassifier():
     # check that the logloss in the test set is actually different when using weights
     # than when not using them
     assert all((logloss_with_weights[i] != logloss_without_weights[i] for i in [0, 1]))
+
+
+def test_save_load_model():
+    tm._skip_if_no_sklearn()
+    from sklearn.datasets import load_digits
+    try:
+        from sklearn.model_selection import KFold
+    except ImportError:
+        from sklearn.cross_validation import KFold
+
+    digits = load_digits(2)
+    y = digits['target']
+    X = digits['data']
+    try:
+        kf = KFold(y.shape[0], n_folds=2, shuffle=True, random_state=rng)
+    except TypeError:  # sklearn.model_selection.KFold uses n_splits
+        kf = KFold(
+            n_splits=2, shuffle=True, random_state=rng
+        ).split(np.arange(y.shape[0]))
+    with TemporaryDirectory() as tempdir:
+        model_path = os.path.join(tempdir, 'digits.model')
+        for train_index, test_index in kf:
+            xgb_model = xgb.XGBClassifier().fit(X[train_index], y[train_index])
+            xgb_model.save_model(model_path)
+            xgb_model = xgb.XGBModel()
+            xgb_model.load_model(model_path)
+            preds = xgb_model.predict(X[test_index])
+            labels = y[test_index]
+            err = sum(1 for i in range(len(preds))
+                      if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
+            assert err < 0.1
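
A minimal round-trip sketch of the new API, assuming the patch above is
applied; the temporary path, file name, and the binary digits data are
illustrative, not part of the patch. As in the test, the model is reloaded
into the base XGBModel wrapper, whose predict() returns raw scores that are
thresholded by hand:

    import os
    import tempfile

    import xgboost as xgb
    from sklearn.datasets import load_digits

    digits = load_digits(2)                  # two-class subset of the digits data
    X, y = digits['data'], digits['target']

    clf = xgb.XGBClassifier().fit(X, y)

    with tempfile.TemporaryDirectory() as tmpdir:  # Python 3 analogue of the test helper
        path = os.path.join(tmpdir, 'digits.model')
        clf.save_model(path)                 # new in this patch: persists the underlying Booster

        restored = xgb.XGBModel()
        restored.load_model(path)            # creates a Booster on demand, then loads into it
        scores = restored.predict(X[:5])     # raw scores, not class labels
        print((scores > 0.5).astype(int))    # threshold into 0/1 labels by hand

Note that save_model persists only the underlying Booster, so sklearn-side
state such as the classifier's label encoder does not survive the round
trip; that appears to be why the test reloads into XGBModel rather than
XGBClassifier.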