Document for process_type. (#7135)

* Update document for prune and refresh.

* Add demo.
This commit is contained in:
Jiaming Yuan 2021-08-03 13:11:52 +08:00 committed by GitHub
parent d080b5a953
commit 7bdedacb54
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 115 additions and 4 deletions

View File

@ -0,0 +1,90 @@
"""Demo for using `process_type` with `prune` and `refresh`. Modifying existing trees is
not a well established use for XGBoost, so feel free to experiment.
"""
import xgboost as xgb
from sklearn.datasets import load_boston
import numpy as np
def main(tree_method: str = "gpu_hist") -> None:
    """Train a booster, then modify its trees with the ``refresh`` and ``prune`` updaters.

    The data is split in half.  A model is trained on the first half, its trees are
    refreshed on the second half (with and without updating leaf values), and finally
    the original trees are pruned to a smaller ``max_depth``.

    Parameters
    ----------
    tree_method :
        Tree method used to train the initial model.  The default ``"gpu_hist"`` is the
        value this demo was written with.  NOTE(review): the checks below were only
        written against the default; confirm they hold before relying on another
        method such as ``"hist"``.

    Raises
    ------
    AssertionError
        If one of the consistency checks on the updated models fails.
    """
    n_rounds = 32
    # NOTE(review): ``load_boston`` was deprecated in scikit-learn 1.0 and removed in
    # 1.2; confirm the scikit-learn version this demo is expected to run against.
    X, y = load_boston(return_X_y=True)

    # Train a model first
    X_train = X[: X.shape[0] // 2]
    y_train = y[: y.shape[0] // 2]
    Xy = xgb.DMatrix(X_train, y_train)
    evals_result: xgb.callback.EvaluationMonitor.EvalsLog = {}
    booster = xgb.train(
        {"tree_method": tree_method, "max_depth": 6},
        Xy,
        num_boost_round=n_rounds,
        evals=[(Xy, "Train")],
        evals_result=evals_result,
    )
    SHAP = booster.predict(Xy, pred_contribs=True)

    # Refresh the leaf value and tree statistic
    X_refresh = X[X.shape[0] // 2:]
    y_refresh = y[y.shape[0] // 2:]
    Xy_refresh = xgb.DMatrix(X_refresh, y_refresh)

    # The model will adapt to other half of the data by changing leaf value (no change in
    # split condition) with refresh_leaf set to True.  The outputs get their own names so
    # they are not clobbered by the second refresh run below and stay available for
    # inspection.
    leaf_refresh_result: xgb.callback.EvaluationMonitor.EvalsLog = {}
    leaf_refreshed = xgb.train(
        {"process_type": "update", "updater": "refresh", "refresh_leaf": True},
        Xy_refresh,
        num_boost_round=n_rounds,
        xgb_model=booster,
        evals=[(Xy, "Original"), (Xy_refresh, "Train")],
        evals_result=leaf_refresh_result,
    )

    # Refresh the model without changing the leaf value, but tree statistic including
    # cover and weight are refreshed.
    refresh_result: xgb.callback.EvaluationMonitor.EvalsLog = {}
    refreshed = xgb.train(
        {"process_type": "update", "updater": "refresh", "refresh_leaf": False},
        Xy_refresh,
        num_boost_round=n_rounds,
        xgb_model=booster,
        evals=[(Xy, "Original"), (Xy_refresh, "Train")],
        evals_result=refresh_result,
    )

    # Without refreshing the leaf value, resulting trees should be the same with original
    # model except for accumulated statistic.  The rtol is for floating point error in
    # prediction.
    np.testing.assert_allclose(
        refresh_result["Original"]["rmse"], evals_result["Train"]["rmse"], rtol=1e-5
    )

    # But SHAP value is changed as cover in tree nodes are changed.
    refreshed_SHAP = refreshed.predict(Xy, pred_contribs=True)
    assert not np.allclose(SHAP, refreshed_SHAP, rtol=1e-3)

    # Prune the trees with smaller max_depth
    X_update = X_train
    y_update = y_train
    Xy_update = xgb.DMatrix(X_update, y_update)

    prune_result: xgb.callback.EvaluationMonitor.EvalsLog = {}
    pruned = xgb.train(
        {"process_type": "update", "updater": "prune", "max_depth": 2},
        Xy_update,
        num_boost_round=n_rounds,
        xgb_model=booster,
        evals=[(Xy, "Original"), (Xy_update, "Train")],
        evals_result=prune_result,
    )

    # Have a smaller model, but similar accuracy.
    # NOTE(review): ``Xy`` and ``Xy_update`` are both built from ``X_train``/``y_train``,
    # so the two series compared here are computed on identical data; this check does
    # not compare the pruned model against the original one.
    np.testing.assert_allclose(
        np.array(prune_result["Original"]["rmse"]),
        np.array(prune_result["Train"]["rmse"]),
        atol=1e-5
    )


if __name__ == "__main__":
    main()

View File

@ -177,7 +177,7 @@ Parameters for Tree Booster
- ``grow_gpu_hist``: Grow tree with GPU. - ``grow_gpu_hist``: Grow tree with GPU.
- ``sync``: synchronizes trees in all distributed nodes. - ``sync``: synchronizes trees in all distributed nodes.
- ``refresh``: refreshes tree's statistics and/or leaf values based on the current data. Note that no random subsampling of data rows is performed. - ``refresh``: refreshes tree's statistics and/or leaf values based on the current data. Note that no random subsampling of data rows is performed.
- ``prune``: prunes the splits where loss < min_split_loss (or gamma). - ``prune``: prunes the splits where loss < min_split_loss (or gamma) and nodes that have depth greater than ``max_depth``.
- In a distributed setting, the implicit updater sequence value would be adjusted to ``grow_histmaker,prune`` by default, and you can set ``tree_method`` as ``hist`` to use ``grow_histmaker``. - In a distributed setting, the implicit updater sequence value would be adjusted to ``grow_histmaker,prune`` by default, and you can set ``tree_method`` as ``hist`` to use ``grow_histmaker``.

View File

@ -78,9 +78,24 @@ slight differences than expectation, which we are currently trying to overcome.
Other Updaters Other Updaters
************** **************
1. ``Pruner``: It prunes the built tree by ``gamma`` parameter. ``pruner`` is usually 1. ``Prune``: It prunes the existing trees. ``prune`` is usually used as part of other
used as part of other tree methods. tree methods. To use pruner independently, one needs to set the process type to update
2. ``Refresh``: Refresh the statistic of built trees on a new training dataset. by: ``{"process_type": "update", "updater": "prune"}``. With this set of parameters,
during training, XGBoost will prune the existing trees according to 2 parameters
``min_split_loss (gamma)`` and ``max_depth``.
2. ``Refresh``: Refresh the statistic of built trees on a new training dataset. Like the
pruner, to use refresh independently, one needs to set the process type to update:
``{"process_type": "update", "updater": "refresh"}``. During training, the updater will
change statistics like ``cover`` and ``weight`` according to the new training dataset.
When ``refresh_leaf`` is also set to true (default), XGBoost will update the leaf value
according to the new leaf weight, but the tree structure (split condition) itself
doesn't change.
There are examples of both training continuation (adding new trees) and using the update
process in ``demo/guide-python``. Also check out the ``process_type`` parameter in
:doc:`parameter`.
3. ``Sync``: Synchronize the tree among workers when running distributed training. 3. ``Sync``: Synchronize the tree among workers when running distributed training.
**************** ****************

View File

@ -14,6 +14,12 @@ def test_data_iterator():
subprocess.check_call(cmd) subprocess.check_call(cmd)
def test_update_process_demo():
    """Run the ``update_process.py`` demo as a subprocess; a non-zero exit fails the test."""
    demo_path = os.path.join(td.PYTHON_DEMO_DIR, 'update_process.py')
    subprocess.check_call(['python', demo_path])
@pytest.mark.skipif(**tm.no_dask()) @pytest.mark.skipif(**tm.no_dask())
@pytest.mark.skipif(**tm.no_dask_cuda()) @pytest.mark.skipif(**tm.no_dask_cuda())
@pytest.mark.skipif(**tm.no_cupy()) @pytest.mark.skipif(**tm.no_cupy())