Save and load model in sklearn API (#3192)
* Add (load|save)_model to XGBModel * Add docstring * Fix docstring * Fix mixed use of space and tab * Add a test * Fix Flake8 style errors
This commit is contained in:
committed by
Philip Hyunsu Cho
parent
24fde92660
commit
594bcea83e
@@ -176,7 +176,7 @@ class XGBModel(XGBModelBase):
|
||||
booster : a xgboost booster of underlying model
|
||||
"""
|
||||
if self._Booster is None:
|
||||
raise XGBoostError('need to call fit beforehand')
|
||||
raise XGBoostError('need to call fit or load_model beforehand')
|
||||
return self._Booster
|
||||
|
||||
def get_params(self, deep=False):
|
||||
@@ -214,6 +214,28 @@ class XGBModel(XGBModelBase):
|
||||
xgb_params.pop('nthread', None)
|
||||
return xgb_params
|
||||
|
||||
def save_model(self, fname):
|
||||
"""
|
||||
Save the model to a file.
|
||||
Parameters
|
||||
----------
|
||||
fname : string
|
||||
Output file name
|
||||
"""
|
||||
self.get_booster().save_model(fname)
|
||||
|
||||
def load_model(self, fname):
|
||||
"""
|
||||
Load the model from a file.
|
||||
Parameters
|
||||
----------
|
||||
fname : string or a memory buffer
|
||||
Input file name or memory buffer(see also save_raw)
|
||||
"""
|
||||
if self._Booster is None:
|
||||
self._Booster = Booster({'nthread': self.n_jobs})
|
||||
self._Booster.load_model(fname)
|
||||
|
||||
def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
|
||||
early_stopping_rounds=None, verbose=True, xgb_model=None,
|
||||
sample_weight_eval_set=None):
|
||||
|
||||
Reference in New Issue
Block a user