python: multiple eval_metrics changes

- allows feval to return a list of tuples (name, error/score value)
- changed behavior for multiple eval_metrics in conjunction with
early_stopping: Instead of raising an error, the last passed evel_metric
(or last entry in return value of feval) is used for early stopping
- allows list of eval_metrics in dict-typed params
- unittest for new features / behavior

documentation updated

- example for assigning a list to 'eval_metric'
- note about early stopping on last passed eval metric

- info msg for used eval metric added
This commit is contained in:
FrozenFingerz 2015-11-03 11:22:00 +01:00
parent 190e58a8c6
commit b59018aa05
4 changed files with 129 additions and 10 deletions

View File

@ -67,10 +67,17 @@ XGBoost use list of pair to save [parameters](../parameter.md). Eg
```python
param = {'bst:max_depth':2, 'bst:eta':1, 'silent':1, 'objective':'binary:logistic' }
param['nthread'] = 4
plst = param.items()
plst += [('eval_metric', 'auc')] # Multiple evals can be handled in this way
plst += [('eval_metric', 'ams@0')]
param['eval_metric'] = 'auc'
```
* You can also specify multiple eval metrics:
```python
param['eval_metric'] = ['auc', 'ams@0']
# alternativly:
# plst = param.items()
# plst += [('eval_metric', 'ams@0')]
```
* Specify validations set to watch performance
```python
evallist = [(dtest,'eval'), (dtrain,'train')]
@ -116,7 +123,7 @@ The model will train until the validation score stops improving. Validation erro
If early stopping occurs, the model will have two additional fields: `bst.best_score` and `bst.best_iteration`. Note that `train()` will return a model from the last iteration, not the best one.
This works with both metrics to minimize (RMSE, log loss, etc.) and to maximize (MAP, NDCG, AUC).
This works with both metrics to minimize (RMSE, log loss, etc.) and to maximize (MAP, NDCG, AUC). Note that if you specify more than one evaluation metric the last one in `param['eval_metric']` is used for early stopping.
Prediction
----------

View File

@ -745,8 +745,13 @@ class Booster(object):
else:
res = '[%d]' % iteration
for dmat, evname in evals:
name, val = feval(self.predict(dmat), dmat)
res += '\t%s-%s:%f' % (evname, name, val)
feval_ret = feval(self.predict(dmat), dmat)
if isinstance(feval_ret, list):
for name, val in feval_ret:
res += '\t%s-%s:%f' % (evname, name, val)
else:
name, val = feval_ret
res += '\t%s-%s:%f' % (evname, name, val)
return res
def eval(self, data, name='eval', iteration=0):

View File

@ -61,6 +61,17 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
booster : a trained booster model
"""
evals = list(evals)
if isinstance(params, dict) \
and 'eval_metric' in params \
and isinstance(params['eval_metric'], list):
params = dict((k, v) for k, v in params.items())
eval_metrics = params['eval_metric']
params.pop("eval_metric", None)
params = list(params.items())
for eval_metric in eval_metrics:
params += [('eval_metric', eval_metric)]
bst = Booster(params, [dtrain] + [d[0] for d in evals])
ntrees = 0
if xgb_model is not None:
if not isinstance(xgb_model, STRING_TYPES):
@ -70,7 +81,6 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
else:
bst = Booster(params, [dtrain] + [d[0] for d in evals])
if evals_result is not None:
if not isinstance(evals_result, dict):
raise TypeError('evals_result has to be a dictionary')
@ -120,9 +130,11 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
# is params a list of tuples? are we using multiple eval metrics?
if isinstance(params, list):
if len(params) != len(dict(params).items()):
raise ValueError('Check your params.'\
'Early stopping works with single eval metric only.')
params = dict(params)
params = dict(params)
sys.stderr.write("Multiple eval metrics has been passed: " \
"'{0}' will be used for early stopping.\n\n".format(params['eval_metric']))
else:
params = dict(params)
# either minimize loss or maximize AUC/MAP/NDCG
maximize_score = False

View File

@ -0,0 +1,95 @@
import xgboost as xgb
import numpy as np
from sklearn.cross_validation import KFold, train_test_split
from sklearn.metrics import mean_squared_error
from sklearn.grid_search import GridSearchCV
from sklearn.datasets import load_iris, load_digits, load_boston
import unittest
rng = np.random.RandomState(1337)
class TestEvalMetrics(unittest.TestCase):
xgb_params_01 = {
'silent': 1,
'nthread': 1,
'eval_metric': 'error'
}
xgb_params_02 = {
'silent': 1,
'nthread': 1,
'eval_metric': ['error']
}
xgb_params_03 = {
'silent': 1,
'nthread': 1,
'eval_metric': ['rmse', 'error']
}
xgb_params_04 = {
'silent': 1,
'nthread': 1,
'eval_metric': ['error', 'rmse']
}
def evalerror_01(self, preds, dtrain):
labels = dtrain.get_label()
return 'error', float(sum(labels != (preds > 0.0))) / len(labels)
def evalerror_02(self, preds, dtrain):
labels = dtrain.get_label()
return [('error', float(sum(labels != (preds > 0.0))) / len(labels))]
def evalerror_03(self, preds, dtrain):
labels = dtrain.get_label()
return [('rmse', mean_squared_error(labels, preds)),
('error', float(sum(labels != (preds > 0.0))) / len(labels))]
def evalerror_04(self, preds, dtrain):
labels = dtrain.get_label()
return [('error', float(sum(labels != (preds > 0.0))) / len(labels)),
('rmse', mean_squared_error(labels, preds))]
def test_eval_metrics(self):
digits = load_digits(2)
X = digits['data']
y = digits['target']
Xt, Xv, yt, yv = train_test_split(X, y, test_size=0.2, random_state=0)
dtrain = xgb.DMatrix(Xt, label=yt)
dvalid = xgb.DMatrix(Xv, label=yv)
watchlist = [(dtrain, 'train'), (dvalid, 'val')]
gbdt_01 = xgb.train(self.xgb_params_01, dtrain, num_boost_round=10)
gbdt_02 = xgb.train(self.xgb_params_02, dtrain, num_boost_round=10)
gbdt_03 = xgb.train(self.xgb_params_03, dtrain, num_boost_round=10)
assert gbdt_01.predict(dvalid)[0] == gbdt_02.predict(dvalid)[0]
assert gbdt_01.predict(dvalid)[0] == gbdt_03.predict(dvalid)[0]
gbdt_01 = xgb.train(self.xgb_params_01, dtrain, 10, watchlist,
early_stopping_rounds=2)
gbdt_02 = xgb.train(self.xgb_params_02, dtrain, 10, watchlist,
early_stopping_rounds=2)
gbdt_03 = xgb.train(self.xgb_params_03, dtrain, 10, watchlist,
early_stopping_rounds=2)
gbdt_04 = xgb.train(self.xgb_params_04, dtrain, 10, watchlist,
early_stopping_rounds=2)
assert gbdt_01.predict(dvalid)[0] == gbdt_02.predict(dvalid)[0]
assert gbdt_01.predict(dvalid)[0] == gbdt_03.predict(dvalid)[0]
assert gbdt_03.predict(dvalid)[0] != gbdt_04.predict(dvalid)[0]
gbdt_01 = xgb.train(self.xgb_params_01, dtrain, 10, watchlist,
early_stopping_rounds=2, feval=self.evalerror_01)
gbdt_02 = xgb.train(self.xgb_params_02, dtrain, 10, watchlist,
early_stopping_rounds=2, feval=self.evalerror_02)
gbdt_03 = xgb.train(self.xgb_params_03, dtrain, 10, watchlist,
early_stopping_rounds=2, feval=self.evalerror_03)
gbdt_04 = xgb.train(self.xgb_params_04, dtrain, 10, watchlist,
early_stopping_rounds=2, feval=self.evalerror_04)
assert gbdt_01.predict(dvalid)[0] == gbdt_02.predict(dvalid)[0]
assert gbdt_01.predict(dvalid)[0] == gbdt_03.predict(dvalid)[0]
assert gbdt_03.predict(dvalid)[0] != gbdt_04.predict(dvalid)[0]