Rework Python callback functions. (#6199)
* Define a new callback interface for Python. * Deprecate the old callbacks. * Enable early stopping on dask.
This commit is contained in:
59
doc/python/callbacks.rst
Normal file
59
doc/python/callbacks.rst
Normal file
@@ -0,0 +1,59 @@
|
||||
##################
|
||||
Callback Functions
|
||||
##################
|
||||
|
||||
This document gives a basic walkthrough of callback function used in XGBoost Python
|
||||
package. In XGBoost 1.3, a new callback interface is designed for Python package, which
|
||||
provides the flexiblity of designing various extension for training. Also, XGBoost has a
|
||||
number of pre-defined callbacks for supporting early stopping, checkpoints etc.
|
||||
|
||||
#######################
|
||||
Using builtin callbacks
|
||||
#######################
|
||||
|
||||
By default, training methods in XGBoost have parameters like ``early_stopping_rounds`` and
|
||||
``verbose``/``verbose_eval``, when specified the training procedure will define the
|
||||
corresponding callbacks internally. For example, when ``early_stopping_rounds`` is
|
||||
specified, ``EarlyStopping`` callback is invoked inside iteration loop. You can also pass
|
||||
this callback function directly into XGBoost:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
D_train = xgb.DMatrix(X_train, y_train)
|
||||
D_valid = xgb.DMatrix(X_valid, y_valid)
|
||||
|
||||
# Define a custom evaluation metric used for early stopping.
|
||||
def eval_error_metric(predt, dtrain: xgb.DMatrix):
|
||||
label = dtrain.get_label()
|
||||
r = np.zeros(predt.shape)
|
||||
gt = predt > 0.5
|
||||
r[gt] = 1 - label[gt]
|
||||
le = predt <= 0.5
|
||||
r[le] = label[le]
|
||||
return 'CustomErr', np.sum(r)
|
||||
|
||||
# Specify which dataset and which metric should be used for early stopping.
|
||||
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
|
||||
metric_name='CustomErr',
|
||||
data_name='Train')
|
||||
|
||||
booster = xgb.train(
|
||||
{'objective': 'binary:logistic',
|
||||
'eval_metric': ['error', 'rmse'],
|
||||
'tree_method': 'hist'}, D_train,
|
||||
evals=[(D_train, 'Train'), (D_valid, 'Valid')],
|
||||
feval=eval_error_metric,
|
||||
num_boost_round=1000,
|
||||
callbacks=[early_stop],
|
||||
verbose_eval=False)
|
||||
|
||||
dump = booster.get_dump(dump_format='json')
|
||||
assert len(early_stop.stopping_history['Valid']['CustomErr']) == len(dump)
|
||||
|
||||
##########################
|
||||
Defining your own callback
|
||||
##########################
|
||||
|
||||
XGBoost provides an callback interface class: ``xgboost.callback.TrainingCallback``, user
|
||||
defined callbacks should inherit this class and override corresponding methods. There's a
|
||||
working example in `demo/guide-python/callbacks.py <https://github.com/dmlc/xgboost/tree/master/demo/guide-python/callbacks.py>`_
|
||||
@@ -11,4 +11,5 @@ Contents
|
||||
.. toctree::
|
||||
python_intro
|
||||
python_api
|
||||
callbacks
|
||||
Python examples <https://github.com/dmlc/xgboost/tree/master/demo/guide-python>
|
||||
|
||||
Reference in New Issue
Block a user