Fix early stopping in the Python package (#4638)

* Fix #4630, #4421: Preserve correct ordering between metrics, and always use last metric for early stopping

* Clarify semantics of early stopping in presence of multiple valid sets and metrics

* Add a test

* Fix lint
This commit is contained in:
Philip Hyunsu Cho 2019-07-07 01:01:03 -07:00 committed by GitHub
parent 562d9ae963
commit 1aaf4a679d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 106 additions and 115 deletions

View File

@ -309,8 +309,7 @@ class XGBModel(XGBModelBase):
early_stopping_rounds=None, verbose=True, xgb_model=None, early_stopping_rounds=None, verbose=True, xgb_model=None,
sample_weight_eval_set=None, callbacks=None): sample_weight_eval_set=None, callbacks=None):
# pylint: disable=missing-docstring,invalid-name,attribute-defined-outside-init # pylint: disable=missing-docstring,invalid-name,attribute-defined-outside-init
""" """Fit gradient boosting model
Fit the gradient boosting model
Parameters Parameters
---------- ----------
@ -321,34 +320,39 @@ class XGBModel(XGBModelBase):
sample_weight : array_like sample_weight : array_like
instance weights instance weights
eval_set : list, optional eval_set : list, optional
A list of (X, y) tuple pairs to use as a validation set for A list of (X, y) tuple pairs to use as validation sets, for which
early-stopping metrics will be computed.
Validation metrics will help us track the performance of the model.
sample_weight_eval_set : list, optional sample_weight_eval_set : list, optional
A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
instance weights on the i-th validation set. instance weights on the i-th validation set.
eval_metric : str, callable, optional eval_metric : str, list of str, or callable, optional
If a str, should be a built-in evaluation metric to use. See If a str, should be a built-in evaluation metric to use. See
doc/parameter.rst. If callable, a custom evaluation metric. The call doc/parameter.rst.
signature is func(y_predicted, y_true) where y_true will be a If a list of str, should be the list of multiple built-in evaluation metrics
DMatrix object such that you may need to call the get_label to use.
If callable, a custom evaluation metric. The call
signature is ``func(y_predicted, y_true)`` where ``y_true`` will be a
DMatrix object such that you may need to call the ``get_label``
method. It must return a str, value pair where the str is a name method. It must return a str, value pair where the str is a name
for the evaluation and value is the value of the evaluation for the evaluation and value is the value of the evaluation
function. This objective is always minimized. function. The callable custom objective is always minimized.
early_stopping_rounds : int early_stopping_rounds : int
Activates early stopping. Validation error needs to decrease at Activates early stopping. Validation metric needs to improve at least once in
least every <early_stopping_rounds> round(s) to continue training. every **early_stopping_rounds** round(s) to continue training.
Requires at least one item in evals. If there's more than one, Requires at least one item in **eval_set**.
will use the last. Returns the model from the last iteration The method returns the model from the last iteration (not the best one).
(not the best one). If early stopping occurs, the model will If there's more than one item in **eval_set**, the last entry will be used
have three additional fields: bst.best_score, bst.best_iteration for early stopping.
and bst.best_ntree_limit. If there's more than one metric in **eval_metric**, the last metric will be
(Use bst.best_ntree_limit to get the correct value if num_parallel_tree used for early stopping.
and/or num_class appears in the parameters) If early stopping occurs, the model will have three additional fields:
``clf.best_score``, ``clf.best_iteration`` and ``clf.best_ntree_limit``.
verbose : bool verbose : bool
If `verbose` and an evaluation set is used, writes the evaluation If `verbose` and an evaluation set is used, writes the evaluation
metric measured on the validation set to stderr. metric measured on the validation set to stderr.
xgb_model : str xgb_model : str
file name of stored xgb model or 'Booster' instance Xgb model to be file name of stored XGBoost model or 'Booster' instance XGBoost model to be
loaded before training (allows training continuation). loaded before training (allows training continuation).
callbacks : list of callback functions callbacks : list of callback functions
List of callback functions that are applied at end of each iteration. List of callback functions that are applied at end of each iteration.
@ -629,56 +633,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
early_stopping_rounds=None, verbose=True, xgb_model=None, early_stopping_rounds=None, verbose=True, xgb_model=None,
sample_weight_eval_set=None, callbacks=None): sample_weight_eval_set=None, callbacks=None):
# pylint: disable = attribute-defined-outside-init,arguments-differ # pylint: disable = attribute-defined-outside-init,arguments-differ
"""
Fit gradient boosting classifier
Parameters
----------
X : array_like
Feature matrix
y : array_like
Labels
sample_weight : array_like
Weight for each instance
eval_set : list, optional
A list of (X, y) pairs to use as a validation set for
early-stopping
sample_weight_eval_set : list, optional
A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
instance weights on the i-th validation set.
eval_metric : str, callable, optional
If a str, should be a built-in evaluation metric to use. See
doc/parameter.rst. If callable, a custom evaluation metric. The call
signature is func(y_predicted, y_true) where y_true will be a
DMatrix object such that you may need to call the get_label
method. It must return a str, value pair where the str is a name
for the evaluation and value is the value of the evaluation
function. This objective is always minimized.
early_stopping_rounds : int, optional
Activates early stopping. Validation error needs to decrease at
least every <early_stopping_rounds> round(s) to continue training.
Requires at least one item in evals. If there's more than one,
will use the last. If early stopping occurs, the model will have
three additional fields: bst.best_score, bst.best_iteration and
bst.best_ntree_limit (bst.best_ntree_limit is the ntree_limit parameter
default value in predict method if not any other value is specified).
(Use bst.best_ntree_limit to get the correct value if num_parallel_tree
and/or num_class appears in the parameters)
verbose : bool
If `verbose` and an evaluation set is used, writes the evaluation
metric measured on the validation set to stderr.
xgb_model : str
file name of stored xgb model or 'Booster' instance Xgb model to be
loaded before training (allows training continuation).
callbacks : list of callback functions
List of callback functions that are applied at end of each iteration.
It is possible to use predefined callbacks by using :ref:`callback_api`.
Example:
.. code-block:: python
[xgb.callback.reset_learning_rate(custom_rates)]
"""
evals_result = {} evals_result = {}
self.classes_ = np.unique(y) self.classes_ = np.unique(y)
self.n_classes_ = len(self.classes_) self.n_classes_ = len(self.classes_)
@ -751,6 +706,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
return self return self
fit.__doc__ = XGBModel.fit.__doc__.replace('Fit gradient boosting model',
'Fit gradient boosting classifier', 1)
def predict(self, data, output_margin=False, ntree_limit=None, validate_features=True): def predict(self, data, output_margin=False, ntree_limit=None, validate_features=True):
""" """
Predict with `data`. Predict with `data`.
@ -1027,14 +985,15 @@ class XGBRanker(XGBModel):
Note Note
---- ----
A custom objective function is currently not supported by XGBRanker. A custom objective function is currently not supported by XGBRanker.
Likewise, a custom metric function is not supported either.
Note Note
---- ----
Group information is required for ranking tasks. Query group information is required for ranking tasks.
Before fitting the model, your data need to be sorted by group. When Before fitting the model, your data need to be sorted by query group. When
fitting the model, you need to provide an additional array that fitting the model, you need to provide an additional array that
contains the size of each group. contains the size of each query group.
For example, if your original data look like: For example, if your original data look like:
@ -1086,7 +1045,7 @@ class XGBRanker(XGBModel):
verbose=False, xgb_model=None, callbacks=None): verbose=False, xgb_model=None, callbacks=None):
# pylint: disable = attribute-defined-outside-init,arguments-differ # pylint: disable = attribute-defined-outside-init,arguments-differ
""" """
Fit the gradient boosting model Fit gradient boosting ranker
Parameters Parameters
---------- ----------
@ -1095,57 +1054,57 @@ class XGBRanker(XGBModel):
y : array_like y : array_like
Labels Labels
group : array_like group : array_like
group size of training data Size of each query group of training data. Should have as many elements as
the query groups in the training data
sample_weight : array_like sample_weight : array_like
group weights Query group weights
.. note:: Weights are per-group for ranking tasks .. note:: Weights are per-group for ranking tasks
In ranking task, one weight is assigned to each group (not each data In ranking task, one weight is assigned to each query group (not each
point). This is because we only care about the relative ordering of data point). This is because we only care about the relative ordering of
data points within each group, so it doesn't make sense to assign data points within each group, so it doesn't make sense to assign
weights to individual data points. weights to individual data points.
eval_set : list, optional eval_set : list, optional
A list of (X, y) tuple pairs to use as a validation set for A list of (X, y) tuple pairs to use as validation sets, for which
early-stopping metrics will be computed.
Validation metrics will help us track the performance of the model.
sample_weight_eval_set : list, optional sample_weight_eval_set : list, optional
A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
group weights on the i-th validation set. group weights on the i-th validation set.
.. note:: Weights are per-group for ranking tasks .. note:: Weights are per-group for ranking tasks
In ranking task, one weight is assigned to each group (not each data In ranking task, one weight is assigned to each query group (not each
point). This is because we only care about the relative ordering of data point). This is because we only care about the relative ordering of
data points within each group, so it doesn't make sense to assign data points within each group, so it doesn't make sense to assign
weights to individual data points. weights to individual data points.
eval_group : list of arrays, optional eval_group : list of arrays, optional
A list that contains the group size corresponds to each A list in which ``eval_group[i]`` is the list containing the sizes of all
(X, y) pair in eval_set query groups in the ``i``-th pair in **eval_set**.
eval_metric : str, callable, optional eval_metric : str, list of str, optional
If a str, should be a built-in evaluation metric to use. See If a str, should be a built-in evaluation metric to use. See
doc/parameter.rst. If callable, a custom evaluation metric. The call doc/parameter.rst.
signature is func(y_predicted, y_true) where y_true will be a If a list of str, should be the list of multiple built-in evaluation metrics
DMatrix object such that you may need to call the get_label to use. The custom evaluation metric is not yet supported for the ranker.
method. It must return a str, value pair where the str is a name
for the evaluation and value is the value of the evaluation
function. This objective is always minimized.
early_stopping_rounds : int early_stopping_rounds : int
Activates early stopping. Validation error needs to decrease at Activates early stopping. Validation metric needs to improve at least once in
least every <early_stopping_rounds> round(s) to continue training. every **early_stopping_rounds** round(s) to continue training.
Requires at least one item in evals. If there's more than one, Requires at least one item in **eval_set**.
will use the last. Returns the model from the last iteration The method returns the model from the last iteration (not the best one).
(not the best one). If early stopping occurs, the model will If there's more than one item in **eval_set**, the last entry will be used
have three additional fields: bst.best_score, bst.best_iteration for early stopping.
and bst.best_ntree_limit. If there's more than one metric in **eval_metric**, the last metric will be
(Use bst.best_ntree_limit to get the correct value if num_parallel_tree used for early stopping.
and/or num_class appears in the parameters) If early stopping occurs, the model will have three additional fields:
``clf.best_score``, ``clf.best_iteration`` and ``clf.best_ntree_limit``.
verbose : bool verbose : bool
If `verbose` and an evaluation set is used, writes the evaluation If `verbose` and an evaluation set is used, writes the evaluation
metric measured on the validation set to stderr. metric measured on the validation set to stderr.
xgb_model : str xgb_model : str
file name of stored xgb model or 'Booster' instance Xgb model to be file name of stored XGBoost model or 'Booster' instance XGBoost model to be
loaded before training (allows training continuation). loaded before training (allows training continuation).
callbacks : list of callback functions callbacks : list of callback functions
List of callback functions that are applied at end of each iteration. List of callback functions that are applied at end of each iteration.
@ -1199,8 +1158,8 @@ class XGBRanker(XGBModel):
feval = eval_metric if callable(eval_metric) else None feval = eval_metric if callable(eval_metric) else None
if eval_metric is not None: if eval_metric is not None:
if callable(eval_metric): if callable(eval_metric):
eval_metric = None raise ValueError('Custom evaluation metric is not yet supported' +
else: 'for XGBRanker.')
params.update({'eval_metric': eval_metric}) params.update({'eval_metric': eval_metric})
self._Booster = train(params, train_dmatrix, self._Booster = train(params, train_dmatrix,

View File

@ -127,8 +127,8 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
num_boost_round: int num_boost_round: int
Number of boosting iterations. Number of boosting iterations.
evals: list of pairs (DMatrix, string) evals: list of pairs (DMatrix, string)
List of items to be evaluated during training, this allows user to watch List of validation sets for which metrics will evaluated during training.
performance on the validation set. Validation metrics will help us track the performance of the model.
obj : function obj : function
Customized objective function. Customized objective function.
feval : function feval : function
@ -136,11 +136,14 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
maximize : bool maximize : bool
Whether to maximize feval. Whether to maximize feval.
early_stopping_rounds: int early_stopping_rounds: int
Activates early stopping. Validation error needs to decrease at least Activates early stopping. Validation metric needs to improve at least once in
every **early_stopping_rounds** round(s) to continue training. every **early_stopping_rounds** round(s) to continue training.
Requires at least one item in **evals**. Requires at least one item in **evals**.
If there's more than one, will use the last. The method returns the model from the last iteration (not the best one).
Returns the model from the last iteration (not the best one). If there's more than one item in **evals**, the last entry will be used
for early stopping.
If there's more than one metric in the **eval_metric** parameter given in
**params**, the last metric will be used for early stopping.
If early stopping occurs, the model will have three additional fields: If early stopping occurs, the model will have three additional fields:
``bst.best_score``, ``bst.best_iteration`` and ``bst.best_ntree_limit``. ``bst.best_score``, ``bst.best_iteration`` and ``bst.best_ntree_limit``.
(Use ``bst.best_ntree_limit`` to get the correct value if (Use ``bst.best_ntree_limit`` to get the correct value if
@ -352,16 +355,16 @@ def aggcv(rlist):
for line in rlist: for line in rlist:
arr = line.split() arr = line.split()
assert idx == arr[0] assert idx == arr[0]
for it in arr[1:]: for metric_idx, it in enumerate(arr[1:]):
if not isinstance(it, STRING_TYPES): if not isinstance(it, STRING_TYPES):
it = it.decode() it = it.decode()
k, v = it.split(':') k, v = it.split(':')
if k not in cvmap: if (metric_idx, k) not in cvmap:
cvmap[k] = [] cvmap[(metric_idx, k)] = []
cvmap[k].append(float(v)) cvmap[(metric_idx, k)].append(float(v))
msg = idx msg = idx
results = [] results = []
for k, v in sorted(cvmap.items(), key=lambda x: (x[0].startswith('test'), x[0])): for (metric_idx, k), v in sorted(cvmap.items(), key=lambda x: x[0][0]):
v = np.array(v) v = np.array(v)
if not isinstance(msg, STRING_TYPES): if not isinstance(msg, STRING_TYPES):
msg = msg.decode() msg = msg.decode()
@ -405,9 +408,12 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
maximize : bool maximize : bool
Whether to maximize feval. Whether to maximize feval.
early_stopping_rounds: int early_stopping_rounds: int
Activates early stopping. CV error needs to decrease at least Activates early stopping. Cross-Validation metric (average of validation
every <early_stopping_rounds> round(s) to continue. metric computed over CV folds) needs to improve at least once in
Last entry in evaluation history is the one from best iteration. every **early_stopping_rounds** round(s) to continue training.
The last entry in the evaluation history will represent the best iteration.
If there's more than one metric in the **eval_metric** parameter given in
**params**, the last metric will be used for early stopping.
fpreproc : function fpreproc : function
Preprocessing function that takes (dtrain, dtest, param) and returns Preprocessing function that takes (dtrain, dtest, param) and returns
transformed versions of those. transformed versions of those.

View File

@ -80,3 +80,29 @@ class TestEarlyStopping(unittest.TestCase):
feval=self.evalerror, maximize=True, feval=self.evalerror, maximize=True,
early_stopping_rounds=1) early_stopping_rounds=1)
self.assert_metrics_length(cv, 1) self.assert_metrics_length(cv, 1)
@pytest.mark.skipif(**tm.no_sklearn())
def test_cv_early_stopping_with_multiple_eval_sets_and_metrics(self):
from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True)
dm = xgb.DMatrix(X, label=y)
params = {'objective':'binary:logistic'}
metrics = [['auc'], ['error'], ['logloss'],
['logloss', 'auc'], ['logloss', 'error'], ['error', 'logloss']]
num_iteration_history = []
# If more than one metrics is given, early stopping should use the last metric
for i, m in enumerate(metrics):
result = xgb.cv(params, dm, num_boost_round=1000, nfold=5, stratified=True,
metrics=m, early_stopping_rounds=20, seed=42)
num_iteration_history.append(len(result))
df = result['test-{}-mean'.format(m[-1])]
# When early stopping is invoked, the last metric should be as best it can be.
if m[-1] == 'auc':
assert np.all(df <= df.iloc[-1])
else:
assert np.all(df >= df.iloc[-1])
assert num_iteration_history[:3] == num_iteration_history[3:]