diff --git a/wrapper/xgboost.py b/wrapper/xgboost.py index 5bb6377c5..11dcfdb4b 100644 --- a/wrapper/xgboost.py +++ b/wrapper/xgboost.py @@ -11,6 +11,7 @@ from __future__ import absolute_import import os import sys +import re import ctypes import collections @@ -530,7 +531,7 @@ class Booster(object): return fmap -def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, early_stopping_rounds=None): +def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, early_stopping_rounds=None,evals_result=None): """ Train a booster with given parameters. @@ -542,7 +543,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, ea Data to be trained. num_boost_round: int Number of boosting iterations. - watchlist : list of pairs (DMatrix, string) + watchlist (evals): list of pairs (DMatrix, string) List of items to be evaluated during training, this allows user to watch performance on the validation set. obj : function @@ -557,6 +558,8 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, ea Returns the model from the last iteration (not the best one). If early stopping occurs, the model will have two additional fields: bst.best_score and bst.best_iteration. + evals_result: dict + This dictionary stores the evaluation results of all the items in watchlist Returns ------- @@ -566,15 +569,39 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, ea evals = list(evals) bst = Booster(params, [dtrain] + [d[0] for d in evals]) + if evals_result is not None: + if type(evals_result) is not dict: + raise TypeError('evals_result has to be a dictionary') + else: + evals_name = [d[1] for d in evals] + evals_result.clear() + evals_result.update({key:[] for key in evals_name}) + if not early_stopping_rounds: for i in range(num_boost_round): bst.update(dtrain, i, obj) if len(evals) != 0: bst_eval_set = bst.eval_set(evals, i, feval) if isinstance(bst_eval_set, string_types): - sys.stderr.write(bst_eval_set + '\n') + msg = bst_eval_set + #sys.stderr.write(bst_eval_set + '\n') + # if evals_result is not None: + # res = re.findall(":([0-9.]+).",bst_eval_set) + # for key,val in zip(evals_name,res): + # evals_result[key].append(val) else: - sys.stderr.write(bst_eval_set.decode() + '\n') + msg = bst_eval_set.decode() + # sys.stderr.write(bst_eval_set.decode() + '\n') + # if evals_result is not None: + # res = re.findall(":([0-9.]+).",bst_eval_set.decode()) + # for key,val in zip(evals_name,res): + # evals_result[key].append(val) + + sys.stderr.write(msg + '\n') + if evals_result is not None: + res = re.findall(":([0-9.]+).",msg) + for key,val in zip(evals_name,res): + evals_result[key].append(val) return bst else: @@ -617,6 +644,11 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, ea sys.stderr.write(msg + '\n') + if evals_result is not None: + res = re.findall(":([0-9.]+).",msg) + for key,val in zip(evals_name,res): + evals_result[key].append(val) + score = float(msg.rsplit(':', 1)[1]) if (maximize_score and score > best_score) or \ (not maximize_score and score < best_score):