ENH: Allow early stopping in sklearn API.

commit 0f5f9c0385 (parent 167544d792)
@@ -772,7 +772,6 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
     -------
     booster : a trained booster model
     """
-
     evals = list(evals)
     bst = Booster(params, [dtrain] + [d[0] for d in evals])
@@ -1074,6 +1073,8 @@ class XGBModel(XGBModelBase):
         params = super(XGBModel, self).get_params(deep=deep)
         if params['missing'] is np.nan:
             params['missing'] = None  # sklearn doesn't handle nan. see #4725
+        if not params.get('eval_metric', True):
+            del params['eval_metric']  # don't give as None param to Booster
         return params

     def get_xgb_params(self):
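The two `+` lines keep an unset eval_metric out of the parameter dict that get_params reports: sklearn clones estimators from this dict, and a literal eval_metric=None would otherwise be forwarded to the native Booster as a meaningless parameter. A minimal standalone sketch of the guard's effect (the parameter values are hypothetical, not taken from the estimator):

    # Sketch of the guard added above: drop 'eval_metric' when it is unset
    # (None or empty) so it is never forwarded as a None-valued parameter.
    params = {'max_depth': 3, 'eval_metric': None}   # hypothetical values
    if not params.get('eval_metric', True):
        del params['eval_metric']
    assert 'eval_metric' not in params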
@@ -1086,10 +1087,62 @@ class XGBModel(XGBModelBase):
         xgb_params.pop('nthread', None)
         return xgb_params

-    def fit(self, data, y):
+    def fit(self, X, y, eval_set=None, eval_metric=None,
+            early_stopping_rounds=None, feval=None):
         # pylint: disable=missing-docstring,invalid-name
-        train_dmatrix = DMatrix(data, label=y, missing=self.missing)
-        self._Booster = train(self.get_xgb_params(), train_dmatrix, self.n_estimators)
+        """
+        Fit the gradient boosting model
+
+        Parameters
+        ----------
+        X : array_like
+            Feature matrix
+        y : array_like
+            Labels
+        eval_set : list, optional
+            A list of (X, y) tuple pairs to use as a validation set for
+            early-stopping
+        eval_metric : str, optional
+            Built-in evaluation metric to use.
+        early_stopping_rounds : int, optional
+            Activates early stopping. Validation error needs to decrease at
+            least every <early_stopping_rounds> round(s) to continue training.
+            Requires at least one item in eval_set. If there's more than one,
+            the last one will be used. Returns the model from the last
+            iteration (not the best one). If early stopping occurs, the model
+            will have two additional fields: bst.best_score and
+            bst.best_iteration.
+        feval : function, optional
+            Custom evaluation metric to use. The call signature is
+            feval(y_predicted, y_true), where y_true is a DMatrix object, so
+            you may need to call its get_label method. The metric is always
+            assumed to be minimized, so use -feval when appropriate.
+        """
+        train_dmatrix = DMatrix(X, label=y, missing=self.missing)
+
+        eval_results = {}
+        if eval_set is not None:
+            evals = list(DMatrix(x[0], label=x[1]) for x in eval_set)
+            evals = list(zip(evals,
+                             ["validation_{}".format(i)
+                              for i in range(len(evals))]))
+        else:
+            evals = ()
+
+        params = self.get_xgb_params()
+
+        if eval_metric is not None:
+            params.update({'eval_metric': eval_metric})
+
+        self._Booster = train(params, train_dmatrix,
+                              self.n_estimators, evals=evals,
+                              early_stopping_rounds=early_stopping_rounds,
+                              evals_result=eval_results, feval=feval)
+        if eval_results:
+            eval_results = {k: np.array(v, dtype=float)
+                            for k, v in eval_results.items()}
+            self.eval_results_ = eval_results
+        if early_stopping_rounds is not None:
+            self.best_score_ = self._Booster.best_score
+            self.best_iteration_ = self._Booster.best_iteration
+        return self

     def predict(self, data):
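For orientation, a usage sketch of the new fit signature on synthetic data (the arrays and parameter values are hypothetical; XGBRegressor is used because it inherits this fit from XGBModel):

    import numpy as np
    import xgboost as xgb

    # Hypothetical toy regression split.
    rng = np.random.RandomState(0)
    X_train, y_train = rng.rand(100, 5), rng.rand(100)
    X_val, y_val = rng.rand(20, 5), rng.rand(20)

    model = xgb.XGBRegressor(n_estimators=200)
    model.fit(X_train, y_train,
              eval_set=[(X_val, y_val)],   # validation set watched for early stopping
              eval_metric='rmse',          # built-in metric
              early_stopping_rounds=10)    # stop after 10 non-improving rounds

    # Populated when early stopping is active, as implemented above.
    print(model.best_iteration_, model.best_score_)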
@@ -1117,8 +1170,39 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
                  colsample_bytree,
                  base_score, seed, missing)

-    def fit(self, X, y, sample_weight=None):
+    def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
+            early_stopping_rounds=None, feval=None):
         # pylint: disable = attribute-defined-outside-init,arguments-differ
+        """
+        Fit gradient boosting classifier
+
+        Parameters
+        ----------
+        X : array_like
+            Feature matrix
+        y : array_like
+            Labels
+        sample_weight : array_like
+            Weight for each instance
+        eval_set : list, optional
+            A list of (X, y) pairs to use as a validation set for
+            early-stopping
+        eval_metric : str, optional
+            Built-in evaluation metric to use.
+        early_stopping_rounds : int, optional
+            Activates early stopping. Validation error needs to decrease at
+            least every <early_stopping_rounds> round(s) to continue training.
+            Requires at least one item in eval_set. If there's more than one,
+            the last one will be used. Returns the model from the last
+            iteration (not the best one). If early stopping occurs, the model
+            will have two additional fields: bst.best_score and
+            bst.best_iteration.
+        feval : function, optional
+            Custom evaluation metric to use. The call signature is
+            feval(y_predicted, y_true), where y_true is a DMatrix object, so
+            you may need to call its get_label method. The metric is always
+            assumed to be minimized, so use -feval when appropriate.
+        """
+        eval_results = {}
         self.classes_ = list(np.unique(y))
         self.n_classes_ = len(self.classes_)
         if self.n_classes_ > 2:
@@ -1129,6 +1213,18 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         else:
             xgb_options = self.get_xgb_params()

+        if eval_metric is not None:
+            xgb_options.update({"eval_metric": eval_metric})
+
+        if eval_set is not None:
+            # TODO: use sample_weight if given?
+            evals = list(DMatrix(x[0], label=x[1]) for x in eval_set)
+            nevals = len(evals)
+            eval_names = ["validation_{}".format(i) for i in range(nevals)]
+            evals = list(zip(evals, eval_names))
+        else:
+            evals = ()
+
         self._le = LabelEncoder().fit(y)
         training_labels = self._le.transform(y)
@@ -1139,7 +1235,17 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         train_dmatrix = DMatrix(X, label=training_labels,
                                 missing=self.missing)

-        self._Booster = train(xgb_options, train_dmatrix, self.n_estimators)
+        self._Booster = train(xgb_options, train_dmatrix, self.n_estimators,
+                              evals=evals,
+                              early_stopping_rounds=early_stopping_rounds,
+                              evals_result=eval_results, feval=feval)
+
+        if eval_results:
+            eval_results = {k: np.array(v, dtype=float)
+                            for k, v in eval_results.items()}
+            self.eval_results_ = eval_results
+        if early_stopping_rounds is not None:
+            self.best_score_ = self._Booster.best_score
+            self.best_iteration_ = self._Booster.best_iteration

         return self
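Finally, a sketch of the feval hook both docstrings describe: it receives the predictions and a DMatrix carrying the true labels, and, following xgboost's custom-metric convention, returns a (name, value) pair that training treats as something to minimize, hence the negation (the data and the metric name below are hypothetical):

    import numpy as np
    import xgboost as xgb

    def neg_accuracy(y_predicted, y_true):
        # y_true is a DMatrix, so pull out the raw label array first.
        labels = y_true.get_label()
        acc = np.mean((y_predicted > 0.5) == labels)
        return 'neg_accuracy', -acc   # negated because the metric is minimized

    rng = np.random.RandomState(0)
    X, y = rng.rand(100, 5), rng.randint(0, 2, 100)   # hypothetical binary data

    clf = xgb.XGBClassifier(n_estimators=100)
    clf.fit(X[:80], y[:80],
            eval_set=[(X[80:], y[80:])],
            early_stopping_rounds=5,
            feval=neg_accuracy)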