Save and load model in sklearn API (#3192)

* Add (load|save)_model to XGBModel

* Add docstring

* Fix docstring

* Fix mixed use of space and tab

* Add a test

* Fix Flake8 style errors
Mike Liu authored on 2018-06-30 12:21:49 -07:00; committed by Philip Hyunsu Cho
parent 24fde92660
commit 594bcea83e
2 changed files with 67 additions and 1 deletion
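A minimal sketch of the round trip this commit enables through the sklearn API (the file name and dataset below are illustrative, not part of the commit):

    import xgboost as xgb
    from sklearn.datasets import load_digits

    X, y = load_digits(n_class=2, return_X_y=True)

    # Train through the sklearn wrapper, then persist without touching the Booster directly.
    clf = xgb.XGBClassifier().fit(X, y)
    clf.save_model('digits.model')

    # Restore into a fresh wrapper; only the underlying Booster is recreated.
    model = xgb.XGBModel()
    model.load_model('digits.model')
    preds = model.predict(X)  # raw booster output: probabilities under binary:logistic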

python-package/xgboost/sklearn.py

@@ -176,7 +176,7 @@ class XGBModel(XGBModelBase):
         booster : a xgboost booster of underlying model
         """
         if self._Booster is None:
-            raise XGBoostError('need to call fit beforehand')
+            raise XGBoostError('need to call fit or load_model beforehand')
         return self._Booster
 
     def get_params(self, deep=False):
@@ -214,6 +214,28 @@ class XGBModel(XGBModelBase):
         xgb_params.pop('nthread', None)
         return xgb_params
 
+    def save_model(self, fname):
+        """
+        Save the model to a file.
+        Parameters
+        ----------
+        fname : string
+            Output file name
+        """
+        self.get_booster().save_model(fname)
+
+    def load_model(self, fname):
+        """
+        Load the model from a file.
+        Parameters
+        ----------
+        fname : string or a memory buffer
+            Input file name or memory buffer (see also save_raw)
+        """
+        if self._Booster is None:
+            self._Booster = Booster({'nthread': self.n_jobs})
+        self._Booster.load_model(fname)
+
     def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
             early_stopping_rounds=None, verbose=True, xgb_model=None,
             sample_weight_eval_set=None):
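The load_model docstring above also allows a memory buffer in place of a file name, pointing at Booster.save_raw(); a brief sketch of that path (training as in the earlier example):

    import xgboost as xgb
    from sklearn.datasets import load_digits

    X, y = load_digits(n_class=2, return_X_y=True)
    clf = xgb.XGBClassifier().fit(X, y)

    buf = clf.get_booster().save_raw()  # serialized model as a bytearray

    restored = xgb.XGBModel()
    restored.load_model(buf)            # the buffer stands in for a file name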

tests/python/test_with_sklearn.py

@@ -1,11 +1,24 @@
 import numpy as np
 import xgboost as xgb
 import testing as tm
+import tempfile
+import os
+import shutil
 
 from nose.tools import raises
 
 rng = np.random.RandomState(1994)
 
+class TemporaryDirectory(object):
+    """Context manager for tempfile.mkdtemp()"""
+    def __enter__(self):
+        self.name = tempfile.mkdtemp()
+        return self.name
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        shutil.rmtree(self.name)
+
+
 def test_binary_classification():
     tm._skip_if_no_sklearn()
     from sklearn.datasets import load_digits
@@ -458,3 +471,34 @@ def test_validation_weights_xgbclassifier():
     # check that the logloss in the test set is actually different when using weights
     # than when not using them
     assert all((logloss_with_weights[i] != logloss_without_weights[i] for i in [0, 1]))
+
+
+def test_save_load_model():
+    tm._skip_if_no_sklearn()
+    from sklearn.datasets import load_digits
+    try:
+        from sklearn.model_selection import KFold
+    except ImportError:
+        from sklearn.cross_validation import KFold
+
+    digits = load_digits(2)
+    y = digits['target']
+    X = digits['data']
+    try:
+        kf = KFold(y.shape[0], n_folds=2, shuffle=True, random_state=rng)
+    except TypeError:  # sklearn.model_selection.KFold uses n_splits
+        kf = KFold(
+            n_splits=2, shuffle=True, random_state=rng
+        ).split(np.arange(y.shape[0]))
+    with TemporaryDirectory() as tempdir:
+        model_path = os.path.join(tempdir, 'digits.model')
+        for train_index, test_index in kf:
+            xgb_model = xgb.XGBClassifier().fit(X[train_index], y[train_index])
+            xgb_model.save_model(model_path)
+            xgb_model = xgb.XGBModel()
+            xgb_model.load_model(model_path)
+            preds = xgb_model.predict(X[test_index])
+            labels = y[test_index]
+            err = sum(1 for i in range(len(preds))
+                      if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
+            assert err < 0.1
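Two side notes on the test. First, the model is reloaded into the base XGBModel rather than XGBClassifier, presumably because load_model only reconstructs the underlying Booster: classifier-only state such as the label encoder set up in fit is not restored, which is also why the raw probability output is thresholded at 0.5 by hand. Second, on Python 3.2+ the scaffolding could come straight from the standard library; a sketch, assuming Python 2 support is not needed:

    import os
    import tempfile

    # Standard-library equivalent of the TemporaryDirectory helper defined above.
    with tempfile.TemporaryDirectory() as tempdir:
        model_path = os.path.join(tempdir, 'digits.model')
        # ... save_model / load_model round trip as in test_save_load_model ...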