Add xgb_model parameter to sklearn fit (#2623)
Adding xgb_model paramter allows the continuation of model training. Model has to be saved by calling `model.get_booster().save_model(path)`
This commit is contained in:
parent
6e378452f2
commit
9a81c74a7b
@ -216,7 +216,7 @@ class XGBModel(XGBModelBase):
|
|||||||
return xgb_params
|
return xgb_params
|
||||||
|
|
||||||
def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
|
def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
|
||||||
early_stopping_rounds=None, verbose=True):
|
early_stopping_rounds=None, verbose=True, xgb_model=None):
|
||||||
# pylint: disable=missing-docstring,invalid-name,attribute-defined-outside-init
|
# pylint: disable=missing-docstring,invalid-name,attribute-defined-outside-init
|
||||||
"""
|
"""
|
||||||
Fit the gradient boosting model
|
Fit the gradient boosting model
|
||||||
@ -253,6 +253,9 @@ class XGBModel(XGBModelBase):
|
|||||||
verbose : bool
|
verbose : bool
|
||||||
If `verbose` and an evaluation set is used, writes the evaluation
|
If `verbose` and an evaluation set is used, writes the evaluation
|
||||||
metric measured on the validation set to stderr.
|
metric measured on the validation set to stderr.
|
||||||
|
xgb_model : str
|
||||||
|
file name of stored xgb model or 'Booster' instance Xgb model to be
|
||||||
|
loaded before training (allows training continuation).
|
||||||
"""
|
"""
|
||||||
if sample_weight is not None:
|
if sample_weight is not None:
|
||||||
trainDmatrix = DMatrix(X, label=y, weight=sample_weight,
|
trainDmatrix = DMatrix(X, label=y, weight=sample_weight,
|
||||||
@ -288,7 +291,7 @@ class XGBModel(XGBModelBase):
|
|||||||
self.n_estimators, evals=evals,
|
self.n_estimators, evals=evals,
|
||||||
early_stopping_rounds=early_stopping_rounds,
|
early_stopping_rounds=early_stopping_rounds,
|
||||||
evals_result=evals_result, obj=obj, feval=feval,
|
evals_result=evals_result, obj=obj, feval=feval,
|
||||||
verbose_eval=verbose)
|
verbose_eval=verbose, xgb_model=xgb_model)
|
||||||
|
|
||||||
if evals_result:
|
if evals_result:
|
||||||
for val in evals_result.items():
|
for val in evals_result.items():
|
||||||
@ -406,7 +409,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
random_state, seed, missing, **kwargs)
|
random_state, seed, missing, **kwargs)
|
||||||
|
|
||||||
def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
|
def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
|
||||||
early_stopping_rounds=None, verbose=True):
|
early_stopping_rounds=None, verbose=True, xgb_model=None):
|
||||||
# pylint: disable = attribute-defined-outside-init,arguments-differ
|
# pylint: disable = attribute-defined-outside-init,arguments-differ
|
||||||
"""
|
"""
|
||||||
Fit gradient boosting classifier
|
Fit gradient boosting classifier
|
||||||
@ -443,6 +446,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
verbose : bool
|
verbose : bool
|
||||||
If `verbose` and an evaluation set is used, writes the evaluation
|
If `verbose` and an evaluation set is used, writes the evaluation
|
||||||
metric measured on the validation set to stderr.
|
metric measured on the validation set to stderr.
|
||||||
|
xgb_model : str
|
||||||
|
file name of stored xgb model or 'Booster' instance Xgb model to be
|
||||||
|
loaded before training (allows training continuation).
|
||||||
"""
|
"""
|
||||||
evals_result = {}
|
evals_result = {}
|
||||||
self.classes_ = np.unique(y)
|
self.classes_ = np.unique(y)
|
||||||
@ -498,7 +504,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
evals=evals,
|
evals=evals,
|
||||||
early_stopping_rounds=early_stopping_rounds,
|
early_stopping_rounds=early_stopping_rounds,
|
||||||
evals_result=evals_result, obj=obj, feval=feval,
|
evals_result=evals_result, obj=obj, feval=feval,
|
||||||
verbose_eval=verbose)
|
verbose_eval=verbose, xgb_model=None)
|
||||||
|
|
||||||
self.objective = xgb_options["objective"]
|
self.objective = xgb_options["objective"]
|
||||||
if evals_result:
|
if evals_result:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user