Merge pull request #257 from yzliao/master

Python: record evaluation results in train()
This commit is contained in:
Tianqi Chen 2015-04-23 21:51:09 -07:00
commit b94f7b0849

View File

@ -11,6 +11,7 @@ from __future__ import absolute_import
import os import os
import sys import sys
import re
import ctypes import ctypes
import collections import collections
@ -530,7 +531,8 @@ class Booster(object):
return fmap return fmap
def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, early_stopping_rounds=None): def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
early_stopping_rounds=None,evals_result=None):
""" """
Train a booster with given parameters. Train a booster with given parameters.
@ -542,7 +544,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, ea
Data to be trained. Data to be trained.
num_boost_round: int num_boost_round: int
Number of boosting iterations. Number of boosting iterations.
watchlist : list of pairs (DMatrix, string) watchlist (evals): list of pairs (DMatrix, string)
List of items to be evaluated during training, this allows user to watch List of items to be evaluated during training, this allows user to watch
performance on the validation set. performance on the validation set.
obj : function obj : function
@ -557,6 +559,8 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, ea
Returns the model from the last iteration (not the best one). Returns the model from the last iteration (not the best one).
If early stopping occurs, the model will have two additional fields: If early stopping occurs, the model will have two additional fields:
bst.best_score and bst.best_iteration. bst.best_score and bst.best_iteration.
evals_result: dict
This dictionary stores the evaluation results of all the items in the watchlist (evals).
Returns Returns
------- -------
@ -566,15 +570,29 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, ea
evals = list(evals) evals = list(evals)
bst = Booster(params, [dtrain] + [d[0] for d in evals]) bst = Booster(params, [dtrain] + [d[0] for d in evals])
if evals_result is not None:
if type(evals_result) is not dict:
raise TypeError('evals_result has to be a dictionary')
else:
evals_name = [d[1] for d in evals]
evals_result.clear()
evals_result.update({key:[] for key in evals_name})
if not early_stopping_rounds: if not early_stopping_rounds:
for i in range(num_boost_round): for i in range(num_boost_round):
bst.update(dtrain, i, obj) bst.update(dtrain, i, obj)
if len(evals) != 0: if len(evals) != 0:
bst_eval_set = bst.eval_set(evals, i, feval) bst_eval_set = bst.eval_set(evals, i, feval)
if isinstance(bst_eval_set, string_types): if isinstance(bst_eval_set, string_types):
sys.stderr.write(bst_eval_set + '\n') msg = bst_eval_set
else: else:
sys.stderr.write(bst_eval_set.decode() + '\n') msg = bst_eval_set.decode()
sys.stderr.write(msg + '\n')
if evals_result is not None:
res = re.findall(":([0-9.]+).",msg)
for key,val in zip(evals_name,res):
evals_result[key].append(val)
return bst return bst
else: else:
@ -617,6 +635,11 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, ea
sys.stderr.write(msg + '\n') sys.stderr.write(msg + '\n')
if evals_result is not None:
res = re.findall(":([0-9.]+).",msg)
for key,val in zip(evals_name,res):
evals_result[key].append(val)
score = float(msg.rsplit(':', 1)[1]) score = float(msg.rsplit(':', 1)[1])
if (maximize_score and score > best_score) or \ if (maximize_score and score > best_score) or \
(not maximize_score and score < best_score): (not maximize_score and score < best_score):