python cv bugfixing

- fixed a bug that occurred when both the eval_metric xgb param and
the metrics param of the cv function were set
- cv early stopping output now looks like that of xgb.train
This commit is contained in:
FrozenFingerz 2015-12-29 12:24:38 +01:00
parent 4f43f1d0ac
commit 2a46918c66

View File

@ -361,7 +361,7 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
Number of boosting iterations. Number of boosting iterations.
nfold : int nfold : int
Number of folds in CV. Number of folds in CV.
metrics : list of strings metrics : string or list of strings
Evaluation metrics to be watched in CV. Evaluation metrics to be watched in CV.
obj : function obj : function
Custom objective function. Custom objective function.
@ -394,6 +394,25 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
------- -------
evaluation history : list(string) evaluation history : list(string)
""" """
if isinstance(metrics, str):
metrics = [metrics]
if isinstance(params, list):
_metrics = [x[1] for x in params if x[0] == 'eval_metric']
params = dict(params)
if 'eval_metric' in params:
params['eval_metric'] = _metrics
else:
params= dict((k, v) for k, v in params.items())
if len(metrics) == 0 and 'eval_metric' in params:
if isinstance(params['eval_metric'], list):
metrics = params['eval_metric']
else:
metrics = [params['eval_metric']]
params.pop("eval_metric", None)
if early_stopping_rounds is not None: if early_stopping_rounds is not None:
if len(metrics) > 1: if len(metrics) > 1:
raise ValueError('Check your params. '\ raise ValueError('Check your params. '\
@ -434,7 +453,7 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
best_score_i = i best_score_i = i
elif i - best_score_i >= early_stopping_rounds: elif i - best_score_i >= early_stopping_rounds:
results = results[:best_score_i+1] results = results[:best_score_i+1]
sys.stderr.write("Stopping. Best iteration: {} (mean: {}, std: {})\n". sys.stderr.write("Stopping. Best iteration:\n[{}] cv-mean:{}\tcv-std:{}\n".
format(best_score_i, results[-1][0], results[-1][1])) format(best_score_i, results[-1][0], results[-1][1]))
break break
if as_pandas: if as_pandas: