Fix Python callback. (#6320)
This commit is contained in:
parent
b181a88f9f
commit
6ff331b705
@ -3,6 +3,8 @@
|
|||||||
# pylint: disable=too-many-branches, too-many-statements
|
# pylint: disable=too-many-branches, too-many-statements
|
||||||
"""Training Library containing training routines."""
|
"""Training Library containing training routines."""
|
||||||
import warnings
|
import warnings
|
||||||
|
import copy
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from .core import Booster, XGBoostError
|
from .core import Booster, XGBoostError
|
||||||
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
|
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
|
||||||
@ -57,7 +59,7 @@ def _train_internal(params, dtrain,
|
|||||||
evals_result=None, maximize=None,
|
evals_result=None, maximize=None,
|
||||||
verbose_eval=None, early_stopping_rounds=None):
|
verbose_eval=None, early_stopping_rounds=None):
|
||||||
"""internal training function"""
|
"""internal training function"""
|
||||||
callbacks = [] if callbacks is None else callbacks
|
callbacks = [] if callbacks is None else copy.copy(callbacks)
|
||||||
evals = list(evals)
|
evals = list(evals)
|
||||||
params = _configure_metrics(params.copy())
|
params = _configure_metrics(params.copy())
|
||||||
|
|
||||||
|
|||||||
@ -232,3 +232,16 @@ class TestCallbacks(unittest.TestCase):
|
|||||||
for i in range(1, 10):
|
for i in range(1, 10):
|
||||||
assert os.path.exists(
|
assert os.path.exists(
|
||||||
os.path.join(tmpdir, 'model_' + str(i) + '.pkl'))
|
os.path.join(tmpdir, 'model_' + str(i) + '.pkl'))
|
||||||
|
|
||||||
|
def test_callback_list(self):
|
||||||
|
X, y = tm.get_boston()
|
||||||
|
m = xgb.DMatrix(X, y)
|
||||||
|
callbacks = [xgb.callback.EarlyStopping(rounds=10)]
|
||||||
|
for i in range(4):
|
||||||
|
xgb.train({'objective': 'reg:squarederror',
|
||||||
|
'eval_metric': 'rmse'}, m,
|
||||||
|
evals=[(m, 'Train')],
|
||||||
|
num_boost_round=1,
|
||||||
|
verbose_eval=True,
|
||||||
|
callbacks=callbacks)
|
||||||
|
assert len(callbacks) == 1
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user