Fix early stopping in the Python package (#4638)

* Fix #4630, #4421: Preserve correct ordering between metrics, and always use last metric for early stopping

* Clarify semantics of early stopping in presence of multiple valid sets and metrics

* Add a test

* Fix lint
This commit is contained in:
Philip Hyunsu Cho 2019-07-07 01:01:03 -07:00 committed by GitHub
parent 562d9ae963
commit 1aaf4a679d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 106 additions and 115 deletions

View File

@ -309,8 +309,7 @@ class XGBModel(XGBModelBase):
early_stopping_rounds=None, verbose=True, xgb_model=None,
sample_weight_eval_set=None, callbacks=None):
# pylint: disable=missing-docstring,invalid-name,attribute-defined-outside-init
"""
Fit the gradient boosting model
"""Fit gradient boosting model
Parameters
----------
@ -321,34 +320,39 @@ class XGBModel(XGBModelBase):
sample_weight : array_like
instance weights
eval_set : list, optional
A list of (X, y) tuple pairs to use as a validation set for
early-stopping
A list of (X, y) tuple pairs to use as validation sets, for which
metrics will be computed.
Validation metrics will help us track the performance of the model.
sample_weight_eval_set : list, optional
A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
instance weights on the i-th validation set.
eval_metric : str, callable, optional
eval_metric : str, list of str, or callable, optional
If a str, should be a built-in evaluation metric to use. See
doc/parameter.rst. If callable, a custom evaluation metric. The call
signature is func(y_predicted, y_true) where y_true will be a
DMatrix object such that you may need to call the get_label
doc/parameter.rst.
If a list of str, should be the list of multiple built-in evaluation metrics
to use.
If callable, a custom evaluation metric. The call
signature is ``func(y_predicted, y_true)`` where ``y_true`` will be a
DMatrix object such that you may need to call the ``get_label``
method. It must return a str, value pair where the str is a name
for the evaluation and value is the value of the evaluation
function. This objective is always minimized.
function. The callable custom objective is always minimized.
early_stopping_rounds : int
Activates early stopping. Validation error needs to decrease at
least every <early_stopping_rounds> round(s) to continue training.
Requires at least one item in evals. If there's more than one,
will use the last. Returns the model from the last iteration
(not the best one). If early stopping occurs, the model will
have three additional fields: bst.best_score, bst.best_iteration
and bst.best_ntree_limit.
(Use bst.best_ntree_limit to get the correct value if num_parallel_tree
and/or num_class appears in the parameters)
Activates early stopping. Validation metric needs to improve at least once in
every **early_stopping_rounds** round(s) to continue training.
Requires at least one item in **eval_set**.
The method returns the model from the last iteration (not the best one).
If there's more than one item in **eval_set**, the last entry will be used
for early stopping.
If there's more than one metric in **eval_metric**, the last metric will be
used for early stopping.
If early stopping occurs, the model will have three additional fields:
``clf.best_score``, ``clf.best_iteration`` and ``clf.best_ntree_limit``.
verbose : bool
If `verbose` and an evaluation set is used, writes the evaluation
metric measured on the validation set to stderr.
xgb_model : str
file name of stored xgb model or 'Booster' instance Xgb model to be
file name of stored XGBoost model or 'Booster' instance XGBoost model to be
loaded before training (allows training continuation).
callbacks : list of callback functions
List of callback functions that are applied at end of each iteration.
@ -629,56 +633,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
early_stopping_rounds=None, verbose=True, xgb_model=None,
sample_weight_eval_set=None, callbacks=None):
# pylint: disable = attribute-defined-outside-init,arguments-differ
"""
Fit gradient boosting classifier
Parameters
----------
X : array_like
Feature matrix
y : array_like
Labels
sample_weight : array_like
Weight for each instance
eval_set : list, optional
A list of (X, y) pairs to use as a validation set for
early-stopping
sample_weight_eval_set : list, optional
A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
instance weights on the i-th validation set.
eval_metric : str, callable, optional
If a str, should be a built-in evaluation metric to use. See
doc/parameter.rst. If callable, a custom evaluation metric. The call
signature is func(y_predicted, y_true) where y_true will be a
DMatrix object such that you may need to call the get_label
method. It must return a str, value pair where the str is a name
for the evaluation and value is the value of the evaluation
function. This objective is always minimized.
early_stopping_rounds : int, optional
Activates early stopping. Validation error needs to decrease at
least every <early_stopping_rounds> round(s) to continue training.
Requires at least one item in evals. If there's more than one,
will use the last. If early stopping occurs, the model will have
three additional fields: bst.best_score, bst.best_iteration and
bst.best_ntree_limit (bst.best_ntree_limit is the ntree_limit parameter
default value in predict method if not any other value is specified).
(Use bst.best_ntree_limit to get the correct value if num_parallel_tree
and/or num_class appears in the parameters)
verbose : bool
If `verbose` and an evaluation set is used, writes the evaluation
metric measured on the validation set to stderr.
xgb_model : str
file name of stored xgb model or 'Booster' instance Xgb model to be
loaded before training (allows training continuation).
callbacks : list of callback functions
List of callback functions that are applied at end of each iteration.
It is possible to use predefined callbacks by using :ref:`callback_api`.
Example:
.. code-block:: python
[xgb.callback.reset_learning_rate(custom_rates)]
"""
evals_result = {}
self.classes_ = np.unique(y)
self.n_classes_ = len(self.classes_)
@ -751,6 +706,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
return self
fit.__doc__ = XGBModel.fit.__doc__.replace('Fit gradient boosting model',
'Fit gradient boosting classifier', 1)
def predict(self, data, output_margin=False, ntree_limit=None, validate_features=True):
"""
Predict with `data`.
@ -1027,14 +985,15 @@ class XGBRanker(XGBModel):
Note
----
A custom objective function is currently not supported by XGBRanker.
Likewise, a custom metric function is not supported either.
Note
----
Group information is required for ranking tasks.
Query group information is required for ranking tasks.
Before fitting the model, your data need to be sorted by group. When
Before fitting the model, your data need to be sorted by query group. When
fitting the model, you need to provide an additional array that
contains the size of each group.
contains the size of each query group.
For example, if your original data look like:
@ -1086,7 +1045,7 @@ class XGBRanker(XGBModel):
verbose=False, xgb_model=None, callbacks=None):
# pylint: disable = attribute-defined-outside-init,arguments-differ
"""
Fit the gradient boosting model
Fit gradient boosting ranker
Parameters
----------
@ -1095,57 +1054,57 @@ class XGBRanker(XGBModel):
y : array_like
Labels
group : array_like
group size of training data
Size of each query group of training data. Should have as many elements as
the query groups in the training data
sample_weight : array_like
group weights
Query group weights
.. note:: Weights are per-group for ranking tasks
In ranking task, one weight is assigned to each group (not each data
point). This is because we only care about the relative ordering of
In ranking task, one weight is assigned to each query group (not each
data point). This is because we only care about the relative ordering of
data points within each group, so it doesn't make sense to assign
weights to individual data points.
eval_set : list, optional
A list of (X, y) tuple pairs to use as a validation set for
early-stopping
A list of (X, y) tuple pairs to use as validation sets, for which
metrics will be computed.
Validation metrics will help us track the performance of the model.
sample_weight_eval_set : list, optional
A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
group weights on the i-th validation set.
.. note:: Weights are per-group for ranking tasks
In ranking task, one weight is assigned to each group (not each data
point). This is because we only care about the relative ordering of
In ranking task, one weight is assigned to each query group (not each
data point). This is because we only care about the relative ordering of
data points within each group, so it doesn't make sense to assign
weights to individual data points.
eval_group : list of arrays, optional
A list that contains the group size corresponds to each
(X, y) pair in eval_set
eval_metric : str, callable, optional
A list in which ``eval_group[i]`` is the list containing the sizes of all
query groups in the ``i``-th pair in **eval_set**.
eval_metric : str, list of str, optional
If a str, should be a built-in evaluation metric to use. See
doc/parameter.rst. If callable, a custom evaluation metric. The call
signature is func(y_predicted, y_true) where y_true will be a
DMatrix object such that you may need to call the get_label
method. It must return a str, value pair where the str is a name
for the evaluation and value is the value of the evaluation
function. This objective is always minimized.
doc/parameter.rst.
If a list of str, should be the list of multiple built-in evaluation metrics
to use. The custom evaluation metric is not yet supported for the ranker.
early_stopping_rounds : int
Activates early stopping. Validation error needs to decrease at
least every <early_stopping_rounds> round(s) to continue training.
Requires at least one item in evals. If there's more than one,
will use the last. Returns the model from the last iteration
(not the best one). If early stopping occurs, the model will
have three additional fields: bst.best_score, bst.best_iteration
and bst.best_ntree_limit.
(Use bst.best_ntree_limit to get the correct value if num_parallel_tree
and/or num_class appears in the parameters)
Activates early stopping. Validation metric needs to improve at least once in
every **early_stopping_rounds** round(s) to continue training.
Requires at least one item in **eval_set**.
The method returns the model from the last iteration (not the best one).
If there's more than one item in **eval_set**, the last entry will be used
for early stopping.
If there's more than one metric in **eval_metric**, the last metric will be
used for early stopping.
If early stopping occurs, the model will have three additional fields:
``clf.best_score``, ``clf.best_iteration`` and ``clf.best_ntree_limit``.
verbose : bool
If `verbose` and an evaluation set is used, writes the evaluation
metric measured on the validation set to stderr.
xgb_model : str
file name of stored xgb model or 'Booster' instance Xgb model to be
file name of stored XGBoost model or 'Booster' instance XGBoost model to be
loaded before training (allows training continuation).
callbacks : list of callback functions
List of callback functions that are applied at end of each iteration.
@ -1199,8 +1158,8 @@ class XGBRanker(XGBModel):
feval = eval_metric if callable(eval_metric) else None
if eval_metric is not None:
if callable(eval_metric):
eval_metric = None
else:
raise ValueError('Custom evaluation metric is not yet supported' +
'for XGBRanker.')
params.update({'eval_metric': eval_metric})
self._Booster = train(params, train_dmatrix,

View File

@ -127,8 +127,8 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
num_boost_round: int
Number of boosting iterations.
evals: list of pairs (DMatrix, string)
List of items to be evaluated during training, this allows user to watch
performance on the validation set.
List of validation sets for which metrics will evaluated during training.
Validation metrics will help us track the performance of the model.
obj : function
Customized objective function.
feval : function
@ -136,11 +136,14 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
maximize : bool
Whether to maximize feval.
early_stopping_rounds: int
Activates early stopping. Validation error needs to decrease at least
Activates early stopping. Validation metric needs to improve at least once in
every **early_stopping_rounds** round(s) to continue training.
Requires at least one item in **evals**.
If there's more than one, will use the last.
Returns the model from the last iteration (not the best one).
The method returns the model from the last iteration (not the best one).
If there's more than one item in **evals**, the last entry will be used
for early stopping.
If there's more than one metric in the **eval_metric** parameter given in
**params**, the last metric will be used for early stopping.
If early stopping occurs, the model will have three additional fields:
``bst.best_score``, ``bst.best_iteration`` and ``bst.best_ntree_limit``.
(Use ``bst.best_ntree_limit`` to get the correct value if
@ -352,16 +355,16 @@ def aggcv(rlist):
for line in rlist:
arr = line.split()
assert idx == arr[0]
for it in arr[1:]:
for metric_idx, it in enumerate(arr[1:]):
if not isinstance(it, STRING_TYPES):
it = it.decode()
k, v = it.split(':')
if k not in cvmap:
cvmap[k] = []
cvmap[k].append(float(v))
if (metric_idx, k) not in cvmap:
cvmap[(metric_idx, k)] = []
cvmap[(metric_idx, k)].append(float(v))
msg = idx
results = []
for k, v in sorted(cvmap.items(), key=lambda x: (x[0].startswith('test'), x[0])):
for (metric_idx, k), v in sorted(cvmap.items(), key=lambda x: x[0][0]):
v = np.array(v)
if not isinstance(msg, STRING_TYPES):
msg = msg.decode()
@ -405,9 +408,12 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
maximize : bool
Whether to maximize feval.
early_stopping_rounds: int
Activates early stopping. CV error needs to decrease at least
every <early_stopping_rounds> round(s) to continue.
Last entry in evaluation history is the one from best iteration.
Activates early stopping. Cross-Validation metric (average of validation
metric computed over CV folds) needs to improve at least once in
every **early_stopping_rounds** round(s) to continue training.
The last entry in the evaluation history will represent the best iteration.
If there's more than one metric in the **eval_metric** parameter given in
**params**, the last metric will be used for early stopping.
fpreproc : function
Preprocessing function that takes (dtrain, dtest, param) and returns
transformed versions of those.

View File

@ -80,3 +80,29 @@ class TestEarlyStopping(unittest.TestCase):
feval=self.evalerror, maximize=True,
early_stopping_rounds=1)
self.assert_metrics_length(cv, 1)
@pytest.mark.skipif(**tm.no_sklearn())
def test_cv_early_stopping_with_multiple_eval_sets_and_metrics(self):
from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True)
dm = xgb.DMatrix(X, label=y)
params = {'objective':'binary:logistic'}
metrics = [['auc'], ['error'], ['logloss'],
['logloss', 'auc'], ['logloss', 'error'], ['error', 'logloss']]
num_iteration_history = []
# If more than one metrics is given, early stopping should use the last metric
for i, m in enumerate(metrics):
result = xgb.cv(params, dm, num_boost_round=1000, nfold=5, stratified=True,
metrics=m, early_stopping_rounds=20, seed=42)
num_iteration_history.append(len(result))
df = result['test-{}-mean'.format(m[-1])]
# When early stopping is invoked, the last metric should be as best it can be.
if m[-1] == 'auc':
assert np.all(df <= df.iloc[-1])
else:
assert np.all(df >= df.iloc[-1])
assert num_iteration_history[:3] == num_iteration_history[3:]