Save and load model in sklearn API (#3192)

* Add (load|save)_model to XGBModel

* Add docstring

* Fix docstring

* Fix mixed use of space and tab

* Add a test

* Fix Flake8 style errors
This commit is contained in:
Mike Liu
2018-06-30 12:21:49 -07:00
committed by Philip Hyunsu Cho
parent 24fde92660
commit 594bcea83e
2 changed files with 67 additions and 1 deletions

View File

@@ -176,7 +176,7 @@ class XGBModel(XGBModelBase):
booster : a xgboost booster of underlying model
"""
if self._Booster is None:
raise XGBoostError('need to call fit beforehand')
raise XGBoostError('need to call fit or load_model beforehand')
return self._Booster
def get_params(self, deep=False):
@@ -214,6 +214,28 @@ class XGBModel(XGBModelBase):
xgb_params.pop('nthread', None)
return xgb_params
def save_model(self, fname):
"""
Save the model to a file.
Parameters
----------
fname : string
Output file name
"""
self.get_booster().save_model(fname)
def load_model(self, fname):
"""
Load the model from a file.
Parameters
----------
fname : string or a memory buffer
Input file name or memory buffer(see also save_raw)
"""
if self._Booster is None:
self._Booster = Booster({'nthread': self.n_jobs})
self._Booster.load_model(fname)
def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
early_stopping_rounds=None, verbose=True, xgb_model=None,
sample_weight_eval_set=None):