Fix early stopping in the Python package (#4638)
* Fix #4630, #4421: Preserve correct ordering between metrics, and always use last metric for early stopping * Clarify semantics of early stopping in presence of multiple valid sets and metrics * Add a test * Fix lint
This commit is contained in:
parent
562d9ae963
commit
1aaf4a679d
@ -309,8 +309,7 @@ class XGBModel(XGBModelBase):
|
||||
early_stopping_rounds=None, verbose=True, xgb_model=None,
|
||||
sample_weight_eval_set=None, callbacks=None):
|
||||
# pylint: disable=missing-docstring,invalid-name,attribute-defined-outside-init
|
||||
"""
|
||||
Fit the gradient boosting model
|
||||
"""Fit gradient boosting model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@ -321,34 +320,39 @@ class XGBModel(XGBModelBase):
|
||||
sample_weight : array_like
|
||||
instance weights
|
||||
eval_set : list, optional
|
||||
A list of (X, y) tuple pairs to use as a validation set for
|
||||
early-stopping
|
||||
A list of (X, y) tuple pairs to use as validation sets, for which
|
||||
metrics will be computed.
|
||||
Validation metrics will help us track the performance of the model.
|
||||
sample_weight_eval_set : list, optional
|
||||
A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
|
||||
instance weights on the i-th validation set.
|
||||
eval_metric : str, callable, optional
|
||||
eval_metric : str, list of str, or callable, optional
|
||||
If a str, should be a built-in evaluation metric to use. See
|
||||
doc/parameter.rst. If callable, a custom evaluation metric. The call
|
||||
signature is func(y_predicted, y_true) where y_true will be a
|
||||
DMatrix object such that you may need to call the get_label
|
||||
doc/parameter.rst.
|
||||
If a list of str, should be the list of multiple built-in evaluation metrics
|
||||
to use.
|
||||
If callable, a custom evaluation metric. The call
|
||||
signature is ``func(y_predicted, y_true)`` where ``y_true`` will be a
|
||||
DMatrix object such that you may need to call the ``get_label``
|
||||
method. It must return a str, value pair where the str is a name
|
||||
for the evaluation and value is the value of the evaluation
|
||||
function. This objective is always minimized.
|
||||
function. The callable custom objective is always minimized.
|
||||
early_stopping_rounds : int
|
||||
Activates early stopping. Validation error needs to decrease at
|
||||
least every <early_stopping_rounds> round(s) to continue training.
|
||||
Requires at least one item in evals. If there's more than one,
|
||||
will use the last. Returns the model from the last iteration
|
||||
(not the best one). If early stopping occurs, the model will
|
||||
have three additional fields: bst.best_score, bst.best_iteration
|
||||
and bst.best_ntree_limit.
|
||||
(Use bst.best_ntree_limit to get the correct value if num_parallel_tree
|
||||
and/or num_class appears in the parameters)
|
||||
Activates early stopping. Validation metric needs to improve at least once in
|
||||
every **early_stopping_rounds** round(s) to continue training.
|
||||
Requires at least one item in **eval_set**.
|
||||
The method returns the model from the last iteration (not the best one).
|
||||
If there's more than one item in **eval_set**, the last entry will be used
|
||||
for early stopping.
|
||||
If there's more than one metric in **eval_metric**, the last metric will be
|
||||
used for early stopping.
|
||||
If early stopping occurs, the model will have three additional fields:
|
||||
``clf.best_score``, ``clf.best_iteration`` and ``clf.best_ntree_limit``.
|
||||
verbose : bool
|
||||
If `verbose` and an evaluation set is used, writes the evaluation
|
||||
metric measured on the validation set to stderr.
|
||||
xgb_model : str
|
||||
file name of stored xgb model or 'Booster' instance Xgb model to be
|
||||
file name of stored XGBoost model or 'Booster' instance XGBoost model to be
|
||||
loaded before training (allows training continuation).
|
||||
callbacks : list of callback functions
|
||||
List of callback functions that are applied at end of each iteration.
|
||||
@ -629,56 +633,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
early_stopping_rounds=None, verbose=True, xgb_model=None,
|
||||
sample_weight_eval_set=None, callbacks=None):
|
||||
# pylint: disable = attribute-defined-outside-init,arguments-differ
|
||||
"""
|
||||
Fit gradient boosting classifier
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : array_like
|
||||
Feature matrix
|
||||
y : array_like
|
||||
Labels
|
||||
sample_weight : array_like
|
||||
Weight for each instance
|
||||
eval_set : list, optional
|
||||
A list of (X, y) pairs to use as a validation set for
|
||||
early-stopping
|
||||
sample_weight_eval_set : list, optional
|
||||
A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
|
||||
instance weights on the i-th validation set.
|
||||
eval_metric : str, callable, optional
|
||||
If a str, should be a built-in evaluation metric to use. See
|
||||
doc/parameter.rst. If callable, a custom evaluation metric. The call
|
||||
signature is func(y_predicted, y_true) where y_true will be a
|
||||
DMatrix object such that you may need to call the get_label
|
||||
method. It must return a str, value pair where the str is a name
|
||||
for the evaluation and value is the value of the evaluation
|
||||
function. This objective is always minimized.
|
||||
early_stopping_rounds : int, optional
|
||||
Activates early stopping. Validation error needs to decrease at
|
||||
least every <early_stopping_rounds> round(s) to continue training.
|
||||
Requires at least one item in evals. If there's more than one,
|
||||
will use the last. If early stopping occurs, the model will have
|
||||
three additional fields: bst.best_score, bst.best_iteration and
|
||||
bst.best_ntree_limit (bst.best_ntree_limit is the ntree_limit parameter
|
||||
default value in predict method if not any other value is specified).
|
||||
(Use bst.best_ntree_limit to get the correct value if num_parallel_tree
|
||||
and/or num_class appears in the parameters)
|
||||
verbose : bool
|
||||
If `verbose` and an evaluation set is used, writes the evaluation
|
||||
metric measured on the validation set to stderr.
|
||||
xgb_model : str
|
||||
file name of stored xgb model or 'Booster' instance Xgb model to be
|
||||
loaded before training (allows training continuation).
|
||||
callbacks : list of callback functions
|
||||
List of callback functions that are applied at end of each iteration.
|
||||
It is possible to use predefined callbacks by using :ref:`callback_api`.
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
[xgb.callback.reset_learning_rate(custom_rates)]
|
||||
"""
|
||||
evals_result = {}
|
||||
self.classes_ = np.unique(y)
|
||||
self.n_classes_ = len(self.classes_)
|
||||
@ -751,6 +706,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
|
||||
return self
|
||||
|
||||
fit.__doc__ = XGBModel.fit.__doc__.replace('Fit gradient boosting model',
|
||||
'Fit gradient boosting classifier', 1)
|
||||
|
||||
def predict(self, data, output_margin=False, ntree_limit=None, validate_features=True):
|
||||
"""
|
||||
Predict with `data`.
|
||||
@ -1027,14 +985,15 @@ class XGBRanker(XGBModel):
|
||||
Note
|
||||
----
|
||||
A custom objective function is currently not supported by XGBRanker.
|
||||
Likewise, a custom metric function is not supported either.
|
||||
|
||||
Note
|
||||
----
|
||||
Group information is required for ranking tasks.
|
||||
Query group information is required for ranking tasks.
|
||||
|
||||
Before fitting the model, your data need to be sorted by group. When
|
||||
Before fitting the model, your data need to be sorted by query group. When
|
||||
fitting the model, you need to provide an additional array that
|
||||
contains the size of each group.
|
||||
contains the size of each query group.
|
||||
|
||||
For example, if your original data look like:
|
||||
|
||||
@ -1086,7 +1045,7 @@ class XGBRanker(XGBModel):
|
||||
verbose=False, xgb_model=None, callbacks=None):
|
||||
# pylint: disable = attribute-defined-outside-init,arguments-differ
|
||||
"""
|
||||
Fit the gradient boosting model
|
||||
Fit gradient boosting ranker
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@ -1095,57 +1054,57 @@ class XGBRanker(XGBModel):
|
||||
y : array_like
|
||||
Labels
|
||||
group : array_like
|
||||
group size of training data
|
||||
Size of each query group of training data. Should have as many elements as
|
||||
the query groups in the training data
|
||||
sample_weight : array_like
|
||||
group weights
|
||||
Query group weights
|
||||
|
||||
.. note:: Weights are per-group for ranking tasks
|
||||
|
||||
In ranking task, one weight is assigned to each group (not each data
|
||||
point). This is because we only care about the relative ordering of
|
||||
In ranking task, one weight is assigned to each query group (not each
|
||||
data point). This is because we only care about the relative ordering of
|
||||
data points within each group, so it doesn't make sense to assign
|
||||
weights to individual data points.
|
||||
|
||||
eval_set : list, optional
|
||||
A list of (X, y) tuple pairs to use as a validation set for
|
||||
early-stopping
|
||||
A list of (X, y) tuple pairs to use as validation sets, for which
|
||||
metrics will be computed.
|
||||
Validation metrics will help us track the performance of the model.
|
||||
sample_weight_eval_set : list, optional
|
||||
A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
|
||||
group weights on the i-th validation set.
|
||||
|
||||
.. note:: Weights are per-group for ranking tasks
|
||||
|
||||
In ranking task, one weight is assigned to each group (not each data
|
||||
point). This is because we only care about the relative ordering of
|
||||
In ranking task, one weight is assigned to each query group (not each
|
||||
data point). This is because we only care about the relative ordering of
|
||||
data points within each group, so it doesn't make sense to assign
|
||||
weights to individual data points.
|
||||
|
||||
eval_group : list of arrays, optional
|
||||
A list that contains the group size corresponds to each
|
||||
(X, y) pair in eval_set
|
||||
eval_metric : str, callable, optional
|
||||
A list in which ``eval_group[i]`` is the list containing the sizes of all
|
||||
query groups in the ``i``-th pair in **eval_set**.
|
||||
eval_metric : str, list of str, optional
|
||||
If a str, should be a built-in evaluation metric to use. See
|
||||
doc/parameter.rst. If callable, a custom evaluation metric. The call
|
||||
signature is func(y_predicted, y_true) where y_true will be a
|
||||
DMatrix object such that you may need to call the get_label
|
||||
method. It must return a str, value pair where the str is a name
|
||||
for the evaluation and value is the value of the evaluation
|
||||
function. This objective is always minimized.
|
||||
doc/parameter.rst.
|
||||
If a list of str, should be the list of multiple built-in evaluation metrics
|
||||
to use. The custom evaluation metric is not yet supported for the ranker.
|
||||
early_stopping_rounds : int
|
||||
Activates early stopping. Validation error needs to decrease at
|
||||
least every <early_stopping_rounds> round(s) to continue training.
|
||||
Requires at least one item in evals. If there's more than one,
|
||||
will use the last. Returns the model from the last iteration
|
||||
(not the best one). If early stopping occurs, the model will
|
||||
have three additional fields: bst.best_score, bst.best_iteration
|
||||
and bst.best_ntree_limit.
|
||||
(Use bst.best_ntree_limit to get the correct value if num_parallel_tree
|
||||
and/or num_class appears in the parameters)
|
||||
Activates early stopping. Validation metric needs to improve at least once in
|
||||
every **early_stopping_rounds** round(s) to continue training.
|
||||
Requires at least one item in **eval_set**.
|
||||
The method returns the model from the last iteration (not the best one).
|
||||
If there's more than one item in **eval_set**, the last entry will be used
|
||||
for early stopping.
|
||||
If there's more than one metric in **eval_metric**, the last metric will be
|
||||
used for early stopping.
|
||||
If early stopping occurs, the model will have three additional fields:
|
||||
``clf.best_score``, ``clf.best_iteration`` and ``clf.best_ntree_limit``.
|
||||
verbose : bool
|
||||
If `verbose` and an evaluation set is used, writes the evaluation
|
||||
metric measured on the validation set to stderr.
|
||||
xgb_model : str
|
||||
file name of stored xgb model or 'Booster' instance Xgb model to be
|
||||
file name of stored XGBoost model or 'Booster' instance XGBoost model to be
|
||||
loaded before training (allows training continuation).
|
||||
callbacks : list of callback functions
|
||||
List of callback functions that are applied at end of each iteration.
|
||||
@ -1199,8 +1158,8 @@ class XGBRanker(XGBModel):
|
||||
feval = eval_metric if callable(eval_metric) else None
|
||||
if eval_metric is not None:
|
||||
if callable(eval_metric):
|
||||
eval_metric = None
|
||||
else:
|
||||
raise ValueError('Custom evaluation metric is not yet supported' +
|
||||
'for XGBRanker.')
|
||||
params.update({'eval_metric': eval_metric})
|
||||
|
||||
self._Booster = train(params, train_dmatrix,
|
||||
|
||||
@ -127,8 +127,8 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
||||
num_boost_round: int
|
||||
Number of boosting iterations.
|
||||
evals: list of pairs (DMatrix, string)
|
||||
List of items to be evaluated during training, this allows user to watch
|
||||
performance on the validation set.
|
||||
List of validation sets for which metrics will evaluated during training.
|
||||
Validation metrics will help us track the performance of the model.
|
||||
obj : function
|
||||
Customized objective function.
|
||||
feval : function
|
||||
@ -136,11 +136,14 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
||||
maximize : bool
|
||||
Whether to maximize feval.
|
||||
early_stopping_rounds: int
|
||||
Activates early stopping. Validation error needs to decrease at least
|
||||
Activates early stopping. Validation metric needs to improve at least once in
|
||||
every **early_stopping_rounds** round(s) to continue training.
|
||||
Requires at least one item in **evals**.
|
||||
If there's more than one, will use the last.
|
||||
Returns the model from the last iteration (not the best one).
|
||||
The method returns the model from the last iteration (not the best one).
|
||||
If there's more than one item in **evals**, the last entry will be used
|
||||
for early stopping.
|
||||
If there's more than one metric in the **eval_metric** parameter given in
|
||||
**params**, the last metric will be used for early stopping.
|
||||
If early stopping occurs, the model will have three additional fields:
|
||||
``bst.best_score``, ``bst.best_iteration`` and ``bst.best_ntree_limit``.
|
||||
(Use ``bst.best_ntree_limit`` to get the correct value if
|
||||
@ -352,16 +355,16 @@ def aggcv(rlist):
|
||||
for line in rlist:
|
||||
arr = line.split()
|
||||
assert idx == arr[0]
|
||||
for it in arr[1:]:
|
||||
for metric_idx, it in enumerate(arr[1:]):
|
||||
if not isinstance(it, STRING_TYPES):
|
||||
it = it.decode()
|
||||
k, v = it.split(':')
|
||||
if k not in cvmap:
|
||||
cvmap[k] = []
|
||||
cvmap[k].append(float(v))
|
||||
if (metric_idx, k) not in cvmap:
|
||||
cvmap[(metric_idx, k)] = []
|
||||
cvmap[(metric_idx, k)].append(float(v))
|
||||
msg = idx
|
||||
results = []
|
||||
for k, v in sorted(cvmap.items(), key=lambda x: (x[0].startswith('test'), x[0])):
|
||||
for (metric_idx, k), v in sorted(cvmap.items(), key=lambda x: x[0][0]):
|
||||
v = np.array(v)
|
||||
if not isinstance(msg, STRING_TYPES):
|
||||
msg = msg.decode()
|
||||
@ -405,9 +408,12 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
|
||||
maximize : bool
|
||||
Whether to maximize feval.
|
||||
early_stopping_rounds: int
|
||||
Activates early stopping. CV error needs to decrease at least
|
||||
every <early_stopping_rounds> round(s) to continue.
|
||||
Last entry in evaluation history is the one from best iteration.
|
||||
Activates early stopping. Cross-Validation metric (average of validation
|
||||
metric computed over CV folds) needs to improve at least once in
|
||||
every **early_stopping_rounds** round(s) to continue training.
|
||||
The last entry in the evaluation history will represent the best iteration.
|
||||
If there's more than one metric in the **eval_metric** parameter given in
|
||||
**params**, the last metric will be used for early stopping.
|
||||
fpreproc : function
|
||||
Preprocessing function that takes (dtrain, dtest, param) and returns
|
||||
transformed versions of those.
|
||||
|
||||
@ -80,3 +80,29 @@ class TestEarlyStopping(unittest.TestCase):
|
||||
feval=self.evalerror, maximize=True,
|
||||
early_stopping_rounds=1)
|
||||
self.assert_metrics_length(cv, 1)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_cv_early_stopping_with_multiple_eval_sets_and_metrics(self):
|
||||
from sklearn.datasets import load_breast_cancer
|
||||
|
||||
X, y = load_breast_cancer(return_X_y=True)
|
||||
dm = xgb.DMatrix(X, label=y)
|
||||
params = {'objective':'binary:logistic'}
|
||||
|
||||
metrics = [['auc'], ['error'], ['logloss'],
|
||||
['logloss', 'auc'], ['logloss', 'error'], ['error', 'logloss']]
|
||||
|
||||
num_iteration_history = []
|
||||
|
||||
# If more than one metrics is given, early stopping should use the last metric
|
||||
for i, m in enumerate(metrics):
|
||||
result = xgb.cv(params, dm, num_boost_round=1000, nfold=5, stratified=True,
|
||||
metrics=m, early_stopping_rounds=20, seed=42)
|
||||
num_iteration_history.append(len(result))
|
||||
df = result['test-{}-mean'.format(m[-1])]
|
||||
# When early stopping is invoked, the last metric should be as best it can be.
|
||||
if m[-1] == 'auc':
|
||||
assert np.all(df <= df.iloc[-1])
|
||||
else:
|
||||
assert np.all(df >= df.iloc[-1])
|
||||
assert num_iteration_history[:3] == num_iteration_history[3:]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user