diff --git a/tests/python/test_basic_models.py b/tests/python/test_basic_models.py index c81935e9d..1da738e89 100644 --- a/tests/python/test_basic_models.py +++ b/tests/python/test_basic_models.py @@ -105,6 +105,15 @@ class TestModels(unittest.TestCase): if int(preds2[i] > 0.5) != labels[i]) / float(len(preds2)) assert err == err2 + def test_multi_eval_metric(self): + watchlist = [(dtest, 'eval'), (dtrain, 'train')] + param = {'max_depth': 2, 'eta': 0.2, 'silent': 1, 'objective': 'binary:logistic'} + param['eval_metric'] = ["auc", "logloss", 'error'] + evals_result = {} + bst = xgb.train(param, dtrain, 4, watchlist, evals_result=evals_result) + assert len(evals_result['eval']) == 3 + assert set(evals_result['eval'].keys()) == {'auc', 'error', 'logloss'} + def test_fpreproc(self): param = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}