record training progress

This commit is contained in:
Yizheng Liao 2015-04-23 21:24:24 -07:00
parent 4aa1ea2d44
commit 44d1043031

View File

@ -11,6 +11,7 @@ from __future__ import absolute_import
import os import os
import sys import sys
import re
import ctypes import ctypes
import collections import collections
@ -530,7 +531,7 @@ class Booster(object):
return fmap return fmap
def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, early_stopping_rounds=None): def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, early_stopping_rounds=None,evals_result=None):
""" """
Train a booster with given parameters. Train a booster with given parameters.
@ -542,7 +543,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, ea
Data to be trained. Data to be trained.
num_boost_round: int num_boost_round: int
Number of boosting iterations. Number of boosting iterations.
watchlist : list of pairs (DMatrix, string) watchlist (evals): list of pairs (DMatrix, string)
List of items to be evaluated during training, this allows user to watch List of items to be evaluated during training, this allows user to watch
performance on the validation set. performance on the validation set.
obj : function obj : function
@ -557,6 +558,8 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, ea
Returns the model from the last iteration (not the best one). Returns the model from the last iteration (not the best one).
If early stopping occurs, the model will have two additional fields: If early stopping occurs, the model will have two additional fields:
bst.best_score and bst.best_iteration. bst.best_score and bst.best_iteration.
evals_result: dict
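If given, this dictionary is cleared and then filled with the evaluation results of all the items in the watchlist (evals): each key is an item's name and each value is a list with one entry appended per evaluated boosting round.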
Returns Returns
------- -------
@ -566,15 +569,39 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, ea
evals = list(evals) evals = list(evals)
bst = Booster(params, [dtrain] + [d[0] for d in evals]) bst = Booster(params, [dtrain] + [d[0] for d in evals])
if evals_result is not None:
if type(evals_result) is not dict:
raise TypeError('evals_result has to be a dictionary')
else:
evals_name = [d[1] for d in evals]
evals_result.clear()
evals_result.update({key:[] for key in evals_name})
if not early_stopping_rounds: if not early_stopping_rounds:
for i in range(num_boost_round): for i in range(num_boost_round):
bst.update(dtrain, i, obj) bst.update(dtrain, i, obj)
if len(evals) != 0: if len(evals) != 0:
bst_eval_set = bst.eval_set(evals, i, feval) bst_eval_set = bst.eval_set(evals, i, feval)
if isinstance(bst_eval_set, string_types): if isinstance(bst_eval_set, string_types):
sys.stderr.write(bst_eval_set + '\n') msg = bst_eval_set
#sys.stderr.write(bst_eval_set + '\n')
# if evals_result is not None:
# res = re.findall(":([0-9.]+).",bst_eval_set)
# for key,val in zip(evals_name,res):
# evals_result[key].append(val)
else: else:
sys.stderr.write(bst_eval_set.decode() + '\n') msg = bst_eval_set.decode()
# sys.stderr.write(bst_eval_set.decode() + '\n')
# if evals_result is not None:
# res = re.findall(":([0-9.]+).",bst_eval_set.decode())
# for key,val in zip(evals_name,res):
# evals_result[key].append(val)
sys.stderr.write(msg + '\n')
if evals_result is not None:
res = re.findall(":([0-9.]+).",msg)
for key,val in zip(evals_name,res):
evals_result[key].append(val)
return bst return bst
else: else:
@ -617,6 +644,11 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, ea
sys.stderr.write(msg + '\n') sys.stderr.write(msg + '\n')
if evals_result is not None:
res = re.findall(":([0-9.]+).",msg)
for key,val in zip(evals_name,res):
evals_result[key].append(val)
score = float(msg.rsplit(':', 1)[1]) score = float(msg.rsplit(':', 1)[1])
if (maximize_score and score > best_score) or \ if (maximize_score and score > best_score) or \
(not maximize_score and score < best_score): (not maximize_score and score < best_score):