diff --git a/python-package/xgboost/training.py b/python-package/xgboost/training.py index fd7565e52..bfb39e837 100644 --- a/python-package/xgboost/training.py +++ b/python-package/xgboost/training.py @@ -361,7 +361,7 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(), Number of boosting iterations. nfold : int Number of folds in CV. - metrics : list of strings + metrics : string or list of strings Evaluation metrics to be watched in CV. obj : function Custom objective function. @@ -394,9 +394,28 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(), ------- evaluation history : list(string) """ + if isinstance(metrics, str): + metrics = [metrics] + + if isinstance(params, list): + _metrics = [x[1] for x in params if x[0] == 'eval_metric'] + params = dict(params) + if 'eval_metric' in params: + params['eval_metric'] = _metrics + else: + params= dict((k, v) for k, v in params.items()) + + if len(metrics) == 0 and 'eval_metric' in params: + if isinstance(params['eval_metric'], list): + metrics = params['eval_metric'] + else: + metrics = [params['eval_metric']] + + params.pop("eval_metric", None) + if early_stopping_rounds is not None: if len(metrics) > 1: - raise ValueError('Check your params.'\ + raise ValueError('Check your params. '\ 'Early stopping works with single eval metric only.') sys.stderr.write("Will train until cv error hasn't decreased in {} rounds.\n".format(\ @@ -434,7 +453,7 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(), best_score_i = i elif i - best_score_i >= early_stopping_rounds: results = results[:best_score_i+1] - sys.stderr.write("Stopping. Best iteration: {} (mean: {}, std: {})\n". + sys.stderr.write("Stopping. Best iteration:\n[{}] cv-mean:{}\tcv-std:{}\n". format(best_score_i, results[-1][0], results[-1][1])) break if as_pandas: