Fix #3663: Allow sklearn API to use callbacks (#3682)

* Fix #3663: Allow sklearn API to use callbacks

* Fix lint

* Add Callback API to Python API doc
mrgutkun 2018-09-07 23:51:26 +03:00 committed by Philip Hyunsu Cho
parent 5a8bbb39a1
commit 4b43810f51
4 changed files with 104 additions and 53 deletions


@@ -53,3 +53,15 @@ Plotting API
 .. autofunction:: xgboost.plot_tree
 .. autofunction:: xgboost.to_graphviz
+.. _callback_api:
+
+Callback API
+------------
+.. autofunction:: xgboost.callback.print_evaluation
+.. autofunction:: xgboost.callback.record_evaluation
+.. autofunction:: xgboost.callback.reset_learning_rate
+.. autofunction:: xgboost.callback.early_stop
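For context, a minimal sketch of how these predefined callbacks combine with ``xgb.train`` (synthetic data and illustrative parameter values only, not part of this diff):

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    # Synthetic data, purely for illustration.
    X = np.random.rand(100, 5)
    y = np.random.randint(2, size=100)
    dtrain = xgb.DMatrix(X[:80], label=y[:80])
    dvalid = xgb.DMatrix(X[80:], label=y[80:])

    history = {}
    bst = xgb.train(
        {'objective': 'binary:logistic'},
        dtrain,
        num_boost_round=50,
        evals=[(dvalid, 'valid')],
        callbacks=[
            xgb.callback.print_evaluation(period=10),    # log every 10 rounds
            xgb.callback.record_evaluation(history),     # collect metrics into `history`
            xgb.callback.early_stop(stopping_rounds=5),  # stop if no improvement
        ])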


@@ -32,7 +32,7 @@ def _fmt_metric(value, show_stdv=True):
 def print_evaluation(period=1, show_stdv=True):
     """Create a callback that print evaluation result.
-    We print the evaluation results every ``period`` iterations
+    We print the evaluation results every **period** iterations
     and on the first and the last iterations.
     Parameters
@@ -60,7 +60,7 @@ def print_evaluation(period=1, show_stdv=True):
 def record_evaluation(eval_result):
-    """Create a call back that records the evaluation history into eval_result.
+    """Create a call back that records the evaluation history into **eval_result**.
     Parameters
     ----------
@@ -109,10 +109,11 @@ def reset_learning_rate(learning_rates):
     learning_rates: list or function
         List of learning rate for each boosting round
        or a customized function that calculates eta in terms of
-        current number of round and the total number of boosting round (e.g. yields
-        learning rate decay)
-        - list l: eta = l[boosting_round]
-        - function f: eta = f(boosting_round, num_boost_round)
+        current number of round and the total number of boosting round (e.g.
+        yields learning rate decay)
+        * list ``l``: ``eta = l[boosting_round]``
+        * function ``f``: ``eta = f(boosting_round, num_boost_round)``
     Returns
     -------
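A sketch of the two forms described above, a per-round list and a decay function ``f(boosting_round, num_boost_round)`` (names and rates are illustrative):

.. code-block:: python

    import xgboost as xgb

    # Either give one learning rate per boosting round
    # (the list length must equal num_boost_round) ...
    rates = [0.3, 0.2, 0.1, 0.05, 0.05]
    lr_from_list = xgb.callback.reset_learning_rate(rates)

    # ... or a function of the current round and the total number of rounds.
    def eta_decay(boosting_round, num_boost_round):
        return 0.3 * (0.99 ** boosting_round)

    lr_from_func = xgb.callback.reset_learning_rate(eta_decay)
    # Pass either callback via callbacks=[...] to xgb.train / xgb.cv, or, with
    # this change, to the sklearn wrappers' fit().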
@@ -150,14 +151,14 @@ def early_stop(stopping_rounds, maximize=False, verbose=True):
     """Create a callback that activates early stoppping.
     Validation error needs to decrease at least
-    every <stopping_rounds> round(s) to continue training.
-    Requires at least one item in evals.
+    every **stopping_rounds** round(s) to continue training.
+    Requires at least one item in **evals**.
     If there's more than one, will use the last.
     Returns the model from the last iteration (not the best one).
     If early stopping occurs, the model will have three additional fields:
-    bst.best_score, bst.best_iteration and bst.best_ntree_limit.
-    (Use bst.best_ntree_limit to get the correct value if num_parallel_tree
-    and/or num_class appears in the parameters)
+    ``bst.best_score``, ``bst.best_iteration`` and ``bst.best_ntree_limit``.
+    (Use ``bst.best_ntree_limit`` to get the correct value if ``num_parallel_tree``
+    and/or ``num_class`` appears in the parameters)
     Parameters
     ----------
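A short sketch of the behaviour this docstring describes: train with the ``early_stop`` callback, then predict with ``best_ntree_limit`` (synthetic data; exact attribute availability when early stopping never triggers may vary by version):

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    X = np.random.rand(200, 4)
    y = np.random.rand(200)
    dtrain = xgb.DMatrix(X[:150], label=y[:150])
    dvalid = xgb.DMatrix(X[150:], label=y[150:])

    bst = xgb.train({'objective': 'reg:linear'}, dtrain, num_boost_round=100,
                    evals=[(dvalid, 'valid')],
                    callbacks=[xgb.callback.early_stop(stopping_rounds=10)])

    # The returned model is from the last iteration; use best_ntree_limit to
    # predict with the best iteration found by early stopping.
    preds = bst.predict(dvalid, ntree_limit=bst.best_ntree_limit)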


@@ -1,5 +1,5 @@
 # coding: utf-8
-# pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme, E0012, R0912
+# pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme, E0012, R0912, C0302
 """Scikit-Learn Wrapper interface for XGBoost."""
 from __future__ import absolute_import
@@ -69,9 +69,9 @@ class XGBModel(XGBModelBase):
     booster: string
         Specify which booster to use: gbtree, gblinear or dart.
     nthread : int
-        Number of parallel threads used to run xgboost. (Deprecated, please use n_jobs)
+        Number of parallel threads used to run xgboost. (Deprecated, please use ``n_jobs``)
     n_jobs : int
-        Number of parallel threads used to run xgboost. (replaces nthread)
+        Number of parallel threads used to run xgboost. (replaces ``nthread``)
     gamma : float
         Minimum loss reduction required to make a further partition on a leaf node of the tree.
     min_child_weight : int
@@ -242,7 +242,7 @@ class XGBModel(XGBModelBase):
     def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
             early_stopping_rounds=None, verbose=True, xgb_model=None,
-            sample_weight_eval_set=None):
+            sample_weight_eval_set=None, callbacks=None):
         # pylint: disable=missing-docstring,invalid-name,attribute-defined-outside-init
         """
         Fit the gradient boosting model
@@ -285,6 +285,14 @@ class XGBModel(XGBModelBase):
         xgb_model : str
             file name of stored xgb model or 'Booster' instance Xgb model to be
             loaded before training (allows training continuation).
+        callbacks : list of callback functions
+            List of callback functions that are applied at end of each iteration.
+            It is possible to use predefined callbacks by using :ref:`callback_api`.
+            Example:
+            .. code-block:: python
+                [xgb.callback.reset_learning_rate(custom_rates)]
         """
         if sample_weight is not None:
             trainDmatrix = DMatrix(X, label=y, weight=sample_weight,
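Since this new ``callbacks`` argument is the core of the change, a minimal sketch of the sklearn-style usage it enables (synthetic data; parameter values are only illustrative):

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    X = np.random.rand(120, 6)
    y = np.random.rand(120)
    X_train, X_val = X[:100], X[100:]
    y_train, y_val = y[:100], y[100:]

    reg = xgb.XGBRegressor(n_estimators=50)
    custom_rates = [0.3] * 25 + [0.1] * 25   # one rate per boosting round
    reg.fit(X_train, y_train,
            eval_set=[(X_val, y_val)],
            eval_metric='rmse',
            callbacks=[xgb.callback.reset_learning_rate(custom_rates)])
    print(reg.evals_result())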
@@ -325,7 +333,8 @@ class XGBModel(XGBModelBase):
             self.n_estimators, evals=evals,
             early_stopping_rounds=early_stopping_rounds,
             evals_result=evals_result, obj=obj, feval=feval,
-            verbose_eval=verbose, xgb_model=xgb_model)
+            verbose_eval=verbose, xgb_model=xgb_model,
+            callbacks=callbacks)
         if evals_result:
             for val in evals_result.items():
@@ -413,10 +422,10 @@ class XGBModel(XGBModelBase):
     def evals_result(self):
         """Return the evaluation results.
-        If ``eval_set`` is passed to the `fit` function, you can call ``evals_result()`` to
-        get evaluation results for all passed eval_sets. When ``eval_metric`` is also
-        passed to the ``fit`` function, the ``evals_result`` will contain the ``eval_metrics``
-        passed to the ``fit`` function
+        If **eval_set** is passed to the `fit` function, you can call
+        ``evals_result()`` to get evaluation results for all passed **eval_sets**.
+        When **eval_metric** is also passed to the `fit` function, the
+        **evals_result** will contain the **eval_metrics** passed to the `fit` function.
         Returns
         -------
@@ -438,9 +447,9 @@ class XGBModel(XGBModelBase):
             evals_result = clf.evals_result()
-        The variable evals_result will contain:
-        .. code-block:: none
+        The variable **evals_result** will contain:
+        .. code-block:: python
             {'validation_0': {'logloss': ['0.604835', '0.531479']},
              'validation_1': {'logloss': ['0.41965', '0.17686']}}
@@ -492,7 +501,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
     def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
             early_stopping_rounds=None, verbose=True, xgb_model=None,
-            sample_weight_eval_set=None):
+            sample_weight_eval_set=None, callbacks=None):
         # pylint: disable = attribute-defined-outside-init,arguments-differ
         """
         Fit gradient boosting classifier
@@ -535,6 +544,14 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         xgb_model : str
             file name of stored xgb model or 'Booster' instance Xgb model to be
             loaded before training (allows training continuation).
+        callbacks : list of callback functions
+            List of callback functions that are applied at end of each iteration.
+            It is possible to use predefined callbacks by using :ref:`callback_api`.
+            Example:
+            .. code-block:: python
+                [xgb.callback.reset_learning_rate(custom_rates)]
         """
         evals_result = {}
         self.classes_ = np.unique(y)
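The classifier gains the same argument; a sketch that records per-iteration metrics with ``record_evaluation`` (hypothetical data):

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    X = np.random.rand(150, 4)
    y = np.random.randint(2, size=150)

    history = {}
    clf = xgb.XGBClassifier(n_estimators=20)
    clf.fit(X[:120], y[:120],
            eval_set=[(X[120:], y[120:])],
            eval_metric='logloss',
            callbacks=[xgb.callback.record_evaluation(history)])
    # history now holds {'validation_0': {'logloss': [...]}}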
@@ -592,7 +609,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
             evals=evals,
             early_stopping_rounds=early_stopping_rounds,
             evals_result=evals_result, obj=obj, feval=feval,
-            verbose_eval=verbose, xgb_model=None)
+            verbose_eval=verbose, xgb_model=None,
+            callbacks=callbacks)
         self.objective = xgb_options["objective"]
         if evals_result:
@@ -705,10 +723,10 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
     def evals_result(self):
         """Return the evaluation results.
-        If eval_set is passed to the `fit` function, you can call evals_result() to
-        get evaluation results for all passed eval_sets. When eval_metric is also
-        passed to the `fit` function, the evals_result will contain the eval_metrics
-        passed to the `fit` function
+        If **eval_set** is passed to the `fit` function, you can call
+        ``evals_result()`` to get evaluation results for all passed **eval_sets**.
+        When **eval_metric** is also passed to the `fit` function, the
+        **evals_result** will contain the **eval_metrics** passed to the `fit` function.
         Returns
         -------
@@ -730,9 +748,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
             evals_result = clf.evals_result()
-        The variable ``evals_result`` will contain
-        .. code-block:: none
+        The variable **evals_result** will contain
+        .. code-block:: python
             {'validation_0': {'logloss': ['0.604835', '0.531479']},
              'validation_1': {'logloss': ['0.41965', '0.17686']}}
@@ -771,9 +789,9 @@ class XGBRanker(XGBModel):
     booster: string
         Specify which booster to use: gbtree, gblinear or dart.
     nthread : int
-        Number of parallel threads used to run xgboost. (Deprecated, please use n_jobs)
+        Number of parallel threads used to run xgboost. (Deprecated, please use ``n_jobs``)
     n_jobs : int
-        Number of parallel threads used to run xgboost. (replaces nthread)
+        Number of parallel threads used to run xgboost. (replaces ``nthread``)
     gamma : float
         Minimum loss reduction required to make a further partition on a leaf node of the tree.
     min_child_weight : int
@@ -816,8 +834,12 @@ class XGBRanker(XGBModel):
     ----
     A custom objective function is currently not supported by XGBRanker.
-    Group information is required for ranking tasks. Before fitting the model, your data need to
-    be sorted by group. When fitting the model, you need to provide an additional array that
+    Note
+    ----
+    Group information is required for ranking tasks.
+    Before fitting the model, your data need to be sorted by group. When
+    fitting the model, you need to provide an additional array that
     contains the size of each group.
     For example, if your original data look like:
@@ -863,7 +885,7 @@ class XGBRanker(XGBModel):
     def fit(self, X, y, group, sample_weight=None, eval_set=None, sample_weight_eval_set=None,
             eval_group=None, eval_metric=None, early_stopping_rounds=None,
-            verbose=False, xgb_model=None):
+            verbose=False, xgb_model=None, callbacks=None):
         # pylint: disable = attribute-defined-outside-init,arguments-differ
         """
         Fit the gradient boosting model
@@ -911,6 +933,14 @@ class XGBRanker(XGBModel):
         xgb_model : str
             file name of stored xgb model or 'Booster' instance Xgb model to be
             loaded before training (allows training continuation).
+        callbacks : list of callback functions
+            List of callback functions that are applied at end of each iteration.
+            It is possible to use predefined callbacks by using :ref:`callback_api`.
+            Example:
+            .. code-block:: python
+                [xgb.callback.reset_learning_rate(custom_rates)]
         """
         # check if group information is provided
         if group is None:
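A sketch tying together the group requirement documented above and the new ``callbacks`` argument (toy data, already sorted by group; the group sizes and the decay function are illustrative):

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    # 3 queries with 4, 3 and 5 documents each; rows must be sorted by group.
    X = np.random.rand(12, 5)
    y = np.random.randint(3, size=12)   # relevance labels
    group = [4, 3, 5]                   # sizes sum to the number of rows

    ranker = xgb.XGBRanker(n_estimators=10)
    ranker.fit(X, y, group,
               callbacks=[xgb.callback.reset_learning_rate(
                   lambda cur, total: 0.1 * (0.95 ** cur))])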
@@ -963,7 +993,8 @@ class XGBRanker(XGBModel):
             self.n_estimators,
             early_stopping_rounds=early_stopping_rounds, evals=evals,
             evals_result=evals_result, feval=feval,
-            verbose_eval=verbose, xgb_model=xgb_model)
+            verbose_eval=verbose, xgb_model=xgb_model,
+            callbacks=callbacks)
         self.objective = params["objective"]


@@ -137,34 +137,35 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
         Whether to maximize feval.
     early_stopping_rounds: int
         Activates early stopping. Validation error needs to decrease at least
-        every <early_stopping_rounds> round(s) to continue training.
-        Requires at least one item in evals.
+        every **early_stopping_rounds** round(s) to continue training.
+        Requires at least one item in **evals**.
         If there's more than one, will use the last.
         Returns the model from the last iteration (not the best one).
         If early stopping occurs, the model will have three additional fields:
-        bst.best_score, bst.best_iteration and bst.best_ntree_limit.
-        (Use bst.best_ntree_limit to get the correct value if num_parallel_tree
-        and/or num_class appears in the parameters)
+        ``bst.best_score``, ``bst.best_iteration`` and ``bst.best_ntree_limit``.
+        (Use ``bst.best_ntree_limit`` to get the correct value if
+        ``num_parallel_tree`` and/or ``num_class`` appears in the parameters)
     evals_result: dict
         This dictionary stores the evaluation results of all the items in watchlist.
-        Example: with a watchlist containing [(dtest,'eval'), (dtrain,'train')] and
-        a parameter containing ('eval_metric': 'logloss'), the **evals_result**
-        returns
-        .. code-block:: none
+        Example: with a watchlist containing
+        ``[(dtest,'eval'), (dtrain,'train')]`` and
+        a parameter containing ``('eval_metric': 'logloss')``,
+        the **evals_result** returns
+        .. code-block:: python
            {'train': {'logloss': ['0.48253', '0.35953']},
             'eval': {'logloss': ['0.480385', '0.357756']}}
     verbose_eval : bool or int
-        Requires at least one item in evals.
+        Requires at least one item in **evals**.
         If **verbose_eval** is True then the evaluation metric on the validation set is
         printed at each boosting stage.
         If **verbose_eval** is an integer then the evaluation metric on the validation set
         is printed at every given **verbose_eval** boosting stage. The last boosting stage
        / the boosting stage found by using **early_stopping_rounds** is also printed.
-        Example: with ``verbose_eval=4`` and at least one item in evals, an evaluation metric
+        Example: with ``verbose_eval=4`` and at least one item in **evals**, an evaluation metric
        is printed every 4 boosting stages, instead of every boosting stage.
     learning_rates: list or function (deprecated - use callback API instead)
         List of learning rate for each boosting round
@@ -175,12 +176,17 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
         Xgb model to be loaded before training (allows training continuation).
     callbacks : list of callback functions
         List of callback functions that are applied at end of each iteration.
-        It is possible to use predefined callbacks by using xgb.callback module.
-        Example: [xgb.callback.reset_learning_rate(custom_rates)]
+        It is possible to use predefined callbacks by using
+        :ref:`Callback API <callback_api>`.
+        Example:
+        .. code-block:: python
+            [xgb.callback.reset_learning_rate(custom_rates)]
     Returns
     -------
-    booster : a trained booster model
+    Booster : a trained booster model
     """
     callbacks = [] if callbacks is None else callbacks
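Callbacks passed to ``train`` are simply callables invoked at the end of each iteration; a sketch mixing a hypothetical custom callback (``log_iteration``) with a predefined one (synthetic data, made-up rates):

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    X = np.random.rand(100, 5)
    y = np.random.rand(100)
    dtrain = xgb.DMatrix(X, label=y)

    def log_iteration(env):
        # A callback is just a callable taking the callback environment;
        # the predefined ones in xgb.callback follow the same convention.
        if env.iteration % 10 == 0:
            print('finished round', env.iteration)

    custom_rates = list(np.linspace(0.3, 0.05, 30))  # one rate per round
    bst = xgb.train({'objective': 'reg:linear'}, dtrain, num_boost_round=30,
                    callbacks=[log_iteration,
                               xgb.callback.reset_learning_rate(custom_rates)])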
@@ -334,7 +340,7 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
     folds : a KFold or StratifiedKFold instance or list of fold indices
         Sklearn KFolds or StratifiedKFolds object.
         Alternatively may explicitly pass sample indices for each fold.
-        For ``n`` folds, ``folds`` should be a length ``n`` list of tuples.
+        For ``n`` folds, **folds** should be a length ``n`` list of tuples.
         Each tuple is ``(in,out)`` where ``in`` is a list of indices to be used
         as the training samples for the ``n`` th fold and ``out`` is a list of
         indices to be used as the testing samples for the ``n`` th fold.
@@ -368,10 +374,11 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
         Seed used to generate the folds (passed to numpy.random.seed).
     callbacks : list of callback functions
         List of callback functions that are applied at end of each iteration.
-        It is possible to use predefined callbacks by using xgb.callback module.
+        It is possible to use predefined callbacks by using
+        :ref:`Callback API <callback_api>`.
         Example:
-        .. code-block:: none
+        .. code-block:: python
             [xgb.callback.reset_learning_rate(custom_rates)]
     shuffle : bool
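Finally, a sketch of the same ``callbacks`` argument on ``xgb.cv`` (synthetic data; the returned table holds the per-round cross-validation metrics):

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    X = np.random.rand(200, 5)
    y = np.random.randint(2, size=200)
    dtrain = xgb.DMatrix(X, label=y)

    cv_results = xgb.cv({'objective': 'binary:logistic'}, dtrain,
                        num_boost_round=30, nfold=3,
                        callbacks=[xgb.callback.print_evaluation(period=10,
                                                                 show_stdv=True),
                                   xgb.callback.early_stop(stopping_rounds=5)])
    print(cv_results)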