Merge pull request #1 from quansie/quansie-python-training-patch-1
training.py - pass all eval_metric information to evals_result
This commit is contained in:
commit
8a484e990e
@ -56,7 +56,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
|||||||
else:
|
else:
|
||||||
evals_name = [d[1] for d in evals]
|
evals_name = [d[1] for d in evals]
|
||||||
evals_result.clear()
|
evals_result.clear()
|
||||||
evals_result.update({key: [] for key in evals_name})
|
evals_result.update({key: {} for key in evals_name})
|
||||||
|
|
||||||
if not early_stopping_rounds:
|
if not early_stopping_rounds:
|
||||||
for i in range(num_boost_round):
|
for i in range(num_boost_round):
|
||||||
@ -71,9 +71,18 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
|||||||
if verbose_eval:
|
if verbose_eval:
|
||||||
sys.stderr.write(msg + '\n')
|
sys.stderr.write(msg + '\n')
|
||||||
if evals_result is not None:
|
if evals_result is not None:
|
||||||
res = re.findall(":-?([0-9.]+).", msg)
|
res = re.findall("([0-9a-zA-Z@]+[-]*):-?([0-9.]+).", msg)
|
||||||
for key, val in zip(evals_name, res):
|
for key in evals_name:
|
||||||
evals_result[key].append(val)
|
evals_idx = evals_name.index(key)
|
||||||
|
res_per_eval = len(res) / len(evals_name)
|
||||||
|
for r in range(res_per_eval):
|
||||||
|
res_item = res[(evals_idx*res_per_eval) + r]
|
||||||
|
res_key = res_item[0]
|
||||||
|
res_val = res_item[1]
|
||||||
|
if res_key in evals_result[key]:
|
||||||
|
evals_result[key][res_key].append(res_val)
|
||||||
|
else:
|
||||||
|
evals_result[key][res_key] = [res_val]
|
||||||
return bst
|
return bst
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@ -119,9 +128,18 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
|||||||
sys.stderr.write(msg + '\n')
|
sys.stderr.write(msg + '\n')
|
||||||
|
|
||||||
if evals_result is not None:
|
if evals_result is not None:
|
||||||
res = re.findall(":-?([0-9.]+).", msg)
|
res = re.findall("([0-9a-zA-Z@]+[-]*):-?([0-9.]+).", msg)
|
||||||
for key, val in zip(evals_name, res):
|
for key in evals_name:
|
||||||
evals_result[key].append(val)
|
evals_idx = evals_name.index(key)
|
||||||
|
res_per_eval = len(res) / len(evals_name)
|
||||||
|
for r in range(res_per_eval):
|
||||||
|
res_item = res[(evals_idx*res_per_eval) + r]
|
||||||
|
res_key = res_item[0]
|
||||||
|
res_val = res_item[1]
|
||||||
|
if res_key in evals_result[key]:
|
||||||
|
evals_result[key][res_key].append(res_val)
|
||||||
|
else:
|
||||||
|
evals_result[key][res_key] = [res_val]
|
||||||
|
|
||||||
score = float(msg.rsplit(':', 1)[1])
|
score = float(msg.rsplit(':', 1)[1])
|
||||||
if (maximize_score and score > best_score) or \
|
if (maximize_score and score > best_score) or \
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user