record training progress
This commit is contained in:
parent
4aa1ea2d44
commit
44d1043031
@ -11,6 +11,7 @@ from __future__ import absolute_import
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import re
|
||||||
import ctypes
|
import ctypes
|
||||||
import collections
|
import collections
|
||||||
|
|
||||||
@ -530,7 +531,7 @@ class Booster(object):
|
|||||||
return fmap
|
return fmap
|
||||||
|
|
||||||
|
|
||||||
def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, early_stopping_rounds=None):
|
def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, early_stopping_rounds=None,evals_result=None):
|
||||||
"""
|
"""
|
||||||
Train a booster with given parameters.
|
Train a booster with given parameters.
|
||||||
|
|
||||||
@ -542,7 +543,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, ea
|
|||||||
Data to be trained.
|
Data to be trained.
|
||||||
num_boost_round: int
|
num_boost_round: int
|
||||||
Number of boosting iterations.
|
Number of boosting iterations.
|
||||||
watchlist : list of pairs (DMatrix, string)
|
watchlist (evals): list of pairs (DMatrix, string)
|
||||||
List of items to be evaluated during training, this allows user to watch
|
List of items to be evaluated during training, this allows user to watch
|
||||||
performance on the validation set.
|
performance on the validation set.
|
||||||
obj : function
|
obj : function
|
||||||
@ -557,6 +558,8 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, ea
|
|||||||
Returns the model from the last iteration (not the best one).
|
Returns the model from the last iteration (not the best one).
|
||||||
If early stopping occurs, the model will have two additional fields:
|
If early stopping occurs, the model will have two additional fields:
|
||||||
bst.best_score and bst.best_iteration.
|
bst.best_score and bst.best_iteration.
|
||||||
|
evals_result: dict
|
||||||
|
This dictionary stores the evaluation results of all the items in watchlist
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@ -566,15 +569,39 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, ea
|
|||||||
evals = list(evals)
|
evals = list(evals)
|
||||||
bst = Booster(params, [dtrain] + [d[0] for d in evals])
|
bst = Booster(params, [dtrain] + [d[0] for d in evals])
|
||||||
|
|
||||||
|
if evals_result is not None:
|
||||||
|
if type(evals_result) is not dict:
|
||||||
|
raise TypeError('evals_result has to be a dictionary')
|
||||||
|
else:
|
||||||
|
evals_name = [d[1] for d in evals]
|
||||||
|
evals_result.clear()
|
||||||
|
evals_result.update({key:[] for key in evals_name})
|
||||||
|
|
||||||
if not early_stopping_rounds:
|
if not early_stopping_rounds:
|
||||||
for i in range(num_boost_round):
|
for i in range(num_boost_round):
|
||||||
bst.update(dtrain, i, obj)
|
bst.update(dtrain, i, obj)
|
||||||
if len(evals) != 0:
|
if len(evals) != 0:
|
||||||
bst_eval_set = bst.eval_set(evals, i, feval)
|
bst_eval_set = bst.eval_set(evals, i, feval)
|
||||||
if isinstance(bst_eval_set, string_types):
|
if isinstance(bst_eval_set, string_types):
|
||||||
sys.stderr.write(bst_eval_set + '\n')
|
msg = bst_eval_set
|
||||||
|
#sys.stderr.write(bst_eval_set + '\n')
|
||||||
|
# if evals_result is not None:
|
||||||
|
# res = re.findall(":([0-9.]+).",bst_eval_set)
|
||||||
|
# for key,val in zip(evals_name,res):
|
||||||
|
# evals_result[key].append(val)
|
||||||
else:
|
else:
|
||||||
sys.stderr.write(bst_eval_set.decode() + '\n')
|
msg = bst_eval_set.decode()
|
||||||
|
# sys.stderr.write(bst_eval_set.decode() + '\n')
|
||||||
|
# if evals_result is not None:
|
||||||
|
# res = re.findall(":([0-9.]+).",bst_eval_set.decode())
|
||||||
|
# for key,val in zip(evals_name,res):
|
||||||
|
# evals_result[key].append(val)
|
||||||
|
|
||||||
|
sys.stderr.write(msg + '\n')
|
||||||
|
if evals_result is not None:
|
||||||
|
res = re.findall(":([0-9.]+).",msg)
|
||||||
|
for key,val in zip(evals_name,res):
|
||||||
|
evals_result[key].append(val)
|
||||||
return bst
|
return bst
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@ -617,6 +644,11 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, ea
|
|||||||
|
|
||||||
sys.stderr.write(msg + '\n')
|
sys.stderr.write(msg + '\n')
|
||||||
|
|
||||||
|
if evals_result is not None:
|
||||||
|
res = re.findall(":([0-9.]+).",msg)
|
||||||
|
for key,val in zip(evals_name,res):
|
||||||
|
evals_result[key].append(val)
|
||||||
|
|
||||||
score = float(msg.rsplit(':', 1)[1])
|
score = float(msg.rsplit(':', 1)[1])
|
||||||
if (maximize_score and score > best_score) or \
|
if (maximize_score and score > best_score) or \
|
||||||
(not maximize_score and score < best_score):
|
(not maximize_score and score < best_score):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user