Fix Python callback. (#6320)

This commit is contained in:
Jiaming Yuan 2020-10-30 05:03:44 +08:00 committed by GitHub
parent b181a88f9f
commit 6ff331b705
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 1 deletions

View File

@ -3,6 +3,8 @@
# pylint: disable=too-many-branches, too-many-statements # pylint: disable=too-many-branches, too-many-statements
"""Training Library containing training routines.""" """Training Library containing training routines."""
import warnings import warnings
import copy
import numpy as np import numpy as np
from .core import Booster, XGBoostError from .core import Booster, XGBoostError
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold) from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
@ -57,7 +59,7 @@ def _train_internal(params, dtrain,
evals_result=None, maximize=None, evals_result=None, maximize=None,
verbose_eval=None, early_stopping_rounds=None): verbose_eval=None, early_stopping_rounds=None):
"""internal training function""" """internal training function"""
callbacks = [] if callbacks is None else callbacks callbacks = [] if callbacks is None else copy.copy(callbacks)
evals = list(evals) evals = list(evals)
params = _configure_metrics(params.copy()) params = _configure_metrics(params.copy())

View File

@ -232,3 +232,16 @@ class TestCallbacks(unittest.TestCase):
for i in range(1, 10): for i in range(1, 10):
assert os.path.exists( assert os.path.exists(
os.path.join(tmpdir, 'model_' + str(i) + '.pkl')) os.path.join(tmpdir, 'model_' + str(i) + '.pkl'))
def test_callback_list(self):
X, y = tm.get_boston()
m = xgb.DMatrix(X, y)
callbacks = [xgb.callback.EarlyStopping(rounds=10)]
for i in range(4):
xgb.train({'objective': 'reg:squarederror',
'eval_metric': 'rmse'}, m,
evals=[(m, 'Train')],
num_boost_round=1,
verbose_eval=True,
callbacks=callbacks)
assert len(callbacks) == 1