Fix Python callback. (#6320)
This commit is contained in:
parent
b181a88f9f
commit
6ff331b705
@ -3,6 +3,8 @@
|
||||
# pylint: disable=too-many-branches, too-many-statements
|
||||
"""Training Library containing training routines."""
|
||||
import warnings
|
||||
import copy
|
||||
|
||||
import numpy as np
|
||||
from .core import Booster, XGBoostError
|
||||
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
|
||||
@ -57,7 +59,7 @@ def _train_internal(params, dtrain,
|
||||
evals_result=None, maximize=None,
|
||||
verbose_eval=None, early_stopping_rounds=None):
|
||||
"""internal training function"""
|
||||
callbacks = [] if callbacks is None else callbacks
|
||||
callbacks = [] if callbacks is None else copy.copy(callbacks)
|
||||
evals = list(evals)
|
||||
params = _configure_metrics(params.copy())
|
||||
|
||||
|
||||
@ -232,3 +232,16 @@ class TestCallbacks(unittest.TestCase):
|
||||
for i in range(1, 10):
|
||||
assert os.path.exists(
|
||||
os.path.join(tmpdir, 'model_' + str(i) + '.pkl'))
|
||||
|
||||
def test_callback_list(self):
|
||||
X, y = tm.get_boston()
|
||||
m = xgb.DMatrix(X, y)
|
||||
callbacks = [xgb.callback.EarlyStopping(rounds=10)]
|
||||
for i in range(4):
|
||||
xgb.train({'objective': 'reg:squarederror',
|
||||
'eval_metric': 'rmse'}, m,
|
||||
evals=[(m, 'Train')],
|
||||
num_boost_round=1,
|
||||
verbose_eval=True,
|
||||
callbacks=callbacks)
|
||||
assert len(callbacks) == 1
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user