Jiaming Yuan 2cc9662005
Support slicing tree model (#6302)
This PR is meant to end the confusion around best_ntree_limit and unify model slicing. We have multi-class and random forests; asking users to understand how to set ntree_limit is difficult and error-prone.

* Implement the save_best option in early stopping.

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
2020-11-02 23:27:39 -08:00

##################
Callback Functions
##################

This document gives a basic walkthrough of the callback functions used in the XGBoost
Python package. In XGBoost 1.3, a new callback interface was designed for the Python
package, which provides the flexibility to design various extensions for training. XGBoost
also has a number of predefined callbacks for tasks such as early stopping and
checkpointing.

Using built-in callbacks
------------------------

Training methods in XGBoost have parameters such as ``early_stopping_rounds`` and
``verbose``/``verbose_eval``; when these are specified, the training procedure defines the
corresponding callbacks internally. For example, when ``early_stopping_rounds`` is
specified, an ``EarlyStopping`` callback is invoked inside the iteration loop. You can also
pass this callback directly to XGBoost:

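.. code-block:: python

    import numpy as np
    import xgboost as xgb

    # Placeholder data so the snippet runs on its own; substitute your own split.
    rng = np.random.default_rng(0)
    X = rng.normal(size=(1000, 10))
    y = (X[:, 0] + rng.normal(scale=0.5, size=1000) > 0).astype(int)
    X_train, X_valid, y_train, y_valid = X[:800], X[800:], y[:800], y[800:]
    early_stopping_rounds = 10

    D_train = xgb.DMatrix(X_train, y_train)
    D_valid = xgb.DMatrix(X_valid, y_valid)

    # Define a custom evaluation metric used for early stopping. It counts the
    # misclassified samples, thresholding the predicted probability at 0.5.
    def eval_error_metric(predt, dtrain: xgb.DMatrix):
        label = dtrain.get_label()
        r = np.zeros(predt.shape)
        gt = predt > 0.5
        r[gt] = 1 - label[gt]
        le = predt <= 0.5
        r[le] = label[le]
        return 'CustomErr', np.sum(r)

    # Specify which dataset and which metric should be used for early stopping.
    early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
                                            metric_name='CustomErr',
                                            data_name='Valid')

    booster = xgb.train(
        {'objective': 'binary:logistic',
         'eval_metric': ['error', 'rmse'],
         'tree_method': 'hist'}, D_train,
        evals=[(D_train, 'Train'), (D_valid, 'Valid')],
        feval=eval_error_metric,
        num_boost_round=1000,
        callbacks=[early_stop],
        verbose_eval=False)

    # One score is recorded per boosting round, and this binary objective adds
    # one tree per round, so the history length matches the number of trees.
    dump = booster.get_dump(dump_format='json')
    assert len(early_stop.stopping_history['Valid']['CustomErr']) == len(dump)
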
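
For comparison, here is a small sketch of the built-in route next to the explicit
callback. The synthetic data, the round counts, and the names ``bst_param`` and
``bst_cb`` are illustrative assumptions, not part of the example above. The first call
passes ``early_stopping_rounds`` and lets ``xgb.train`` create the ``EarlyStopping``
callback itself, which by default monitors the last dataset in ``evals``; the second
passes the callback explicitly with ``save_best=True``, which asks it to truncate the
returned booster to the best iteration:

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    # Synthetic binary-classification data (an assumption for this sketch).
    rng = np.random.default_rng(42)
    X = rng.normal(size=(600, 5))
    y = (X[:, 0] + 0.5 * rng.normal(size=600) > 0).astype(int)
    D_train = xgb.DMatrix(X[:400], label=y[:400])
    D_valid = xgb.DMatrix(X[400:], label=y[400:])
    evals = [(D_train, 'Train'), (D_valid, 'Valid')]
    params = {'objective': 'binary:logistic', 'eval_metric': 'logloss',
              'tree_method': 'hist'}

    # Built-in route: early_stopping_rounds makes xgb.train create an EarlyStopping
    # callback internally; it monitors the last dataset in evals ('Valid').
    bst_param = xgb.train(params, D_train, num_boost_round=500, evals=evals,
                          early_stopping_rounds=10, verbose_eval=False)

    # Explicit route: the same callback passed by hand, with save_best=True so the
    # returned booster is truncated to the best iteration.
    early_stop = xgb.callback.EarlyStopping(rounds=10, data_name='Valid',
                                            save_best=True)
    bst_cb = xgb.train(params, D_train, num_boost_round=500, evals=evals,
                       callbacks=[early_stop], verbose_eval=False)

    print('best iteration:', bst_param.best_iteration)
    print('trees kept without save_best:', len(bst_param.get_dump()))
    print('trees kept with save_best:', len(bst_cb.get_dump()))

Without ``save_best``, the booster keeps every tree that was trained, including the
rounds after the best score, and ``best_iteration`` records where that score occurred.
With ``save_best=True`` the later trees are sliced off, so for an objective that adds one
tree per round, such as ``binary:logistic``, the tree count should equal the position of
the best validation score plus one.
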

Defining your own callback
--------------------------

XGBoost provides a callback interface class, ``xgboost.callback.TrainingCallback``.
User-defined callbacks should inherit from this class and override the corresponding
methods. A working example is available in `demo/guide-python/callbacks.py <https://github.com/dmlc/xgboost/tree/master/demo/guide-python/callbacks.py>`_.
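
As a minimal sketch of such a subclass, the hypothetical ``ThresholdStop`` callback below
overrides only ``after_iteration``, which XGBoost calls after each boosting round with the
model, the iteration index, and the evaluation log; returning ``True`` stops training.
The class name, the threshold, and the synthetic data are assumptions for illustration.
``TrainingCallback`` also offers ``before_training``, ``after_training``, and
``before_iteration`` hooks, which are left at their defaults here:

.. code-block:: python

    import numpy as np
    import xgboost as xgb


    class ThresholdStop(xgb.callback.TrainingCallback):
        """Hypothetical callback: stop once a monitored metric drops below a
        threshold, keeping a log of (iteration, score) pairs."""

        def __init__(self, data_name: str, metric_name: str, threshold: float):
            super().__init__()
            self.data_name = data_name
            self.metric_name = metric_name
            self.threshold = threshold
            self.log = []

        def after_iteration(self, model, epoch, evals_log):
            # evals_log maps dataset name -> metric name -> list of scores so far;
            # the last entry is the score of the round that just finished.
            score = evals_log[self.data_name][self.metric_name][-1]
            self.log.append((epoch, score))
            # Returning True tells XGBoost to stop training.
            return score < self.threshold


    # Synthetic, easily separable data (an assumption for this sketch).
    rng = np.random.default_rng(0)
    X = rng.normal(size=(500, 4))
    y = (X[:, 0] > 0).astype(int)
    dtrain = xgb.DMatrix(X, label=y)

    stopper = ThresholdStop(data_name='Train', metric_name='logloss', threshold=0.3)
    booster = xgb.train({'objective': 'binary:logistic', 'eval_metric': 'logloss'},
                        dtrain, num_boost_round=100,
                        evals=[(dtrain, 'Train')], callbacks=[stopper],
                        verbose_eval=False)
    print('stopped after', len(stopper.log), 'rounds; last entry:', stopper.log[-1])

Because the stop request is issued after the tree for that round has been added, the
booster should end up with one tree per logged iteration.
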