Merge pull request #712 from Far0n/py_cv

python cv bugfixing (eval metrics)
This commit is contained in:
Yuan (Terry) Tang 2015-12-29 07:30:26 -06:00
commit d747649892
2 changed files with 84 additions and 23 deletions

View File

@ -361,7 +361,7 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
Number of boosting iterations. Number of boosting iterations.
nfold : int nfold : int
Number of folds in CV. Number of folds in CV.
metrics : list of strings metrics : string or list of strings
Evaluation metrics to be watched in CV. Evaluation metrics to be watched in CV.
obj : function obj : function
Custom objective function. Custom objective function.
@ -394,9 +394,28 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
------- -------
evaluation history : list(string) evaluation history : list(string)
""" """
if isinstance(metrics, str):
metrics = [metrics]
if isinstance(params, list):
_metrics = [x[1] for x in params if x[0] == 'eval_metric']
params = dict(params)
if 'eval_metric' in params:
params['eval_metric'] = _metrics
else:
params= dict((k, v) for k, v in params.items())
if len(metrics) == 0 and 'eval_metric' in params:
if isinstance(params['eval_metric'], list):
metrics = params['eval_metric']
else:
metrics = [params['eval_metric']]
params.pop("eval_metric", None)
if early_stopping_rounds is not None: if early_stopping_rounds is not None:
if len(metrics) > 1: if len(metrics) > 1:
raise ValueError('Check your params.'\ raise ValueError('Check your params. '\
'Early stopping works with single eval metric only.') 'Early stopping works with single eval metric only.')
sys.stderr.write("Will train until cv error hasn't decreased in {} rounds.\n".format(\ sys.stderr.write("Will train until cv error hasn't decreased in {} rounds.\n".format(\
@ -434,7 +453,7 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
best_score_i = i best_score_i = i
elif i - best_score_i >= early_stopping_rounds: elif i - best_score_i >= early_stopping_rounds:
results = results[:best_score_i+1] results = results[:best_score_i+1]
sys.stderr.write("Stopping. Best iteration: {} (mean: {}, std: {})\n". sys.stderr.write("Stopping. Best iteration:\n[{}] cv-mean:{}\tcv-std:{}\n".
format(best_score_i, results[-1][0], results[-1][1])) format(best_score_i, results[-1][0], results[-1][1]))
break break
if as_pandas: if as_pandas:

View File

@ -4,25 +4,26 @@ import xgboost as xgb
import unittest import unittest
import matplotlib import matplotlib
matplotlib.use('Agg') matplotlib.use('Agg')
dpath = 'demo/data/' dpath = 'demo/data/'
rng = np.random.RandomState(1994) rng = np.random.RandomState(1994)
class TestBasic(unittest.TestCase):
class TestBasic(unittest.TestCase):
def test_basic(self): def test_basic(self):
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train') dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test') dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
param = {'max_depth':2, 'eta':1, 'silent':1, 'objective':'binary:logistic' } param = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
# specify validations set to watch performance # specify validations set to watch performance
watchlist = [(dtest,'eval'), (dtrain,'train')] watchlist = [(dtest, 'eval'), (dtrain, 'train')]
num_round = 2 num_round = 2
bst = xgb.train(param, dtrain, num_round, watchlist) bst = xgb.train(param, dtrain, num_round, watchlist)
# this is prediction # this is prediction
preds = bst.predict(dtest) preds = bst.predict(dtest)
labels = dtest.get_label() labels = dtest.get_label()
err = sum(1 for i in range(len(preds)) if int(preds[i]>0.5)!=labels[i]) / float(len(preds)) err = sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
# error must be smaller than 10% # error must be smaller than 10%
assert err < 0.1 assert err < 0.1
@ -35,7 +36,7 @@ class TestBasic(unittest.TestCase):
dtest2 = xgb.DMatrix('dtest.buffer') dtest2 = xgb.DMatrix('dtest.buffer')
preds2 = bst2.predict(dtest2) preds2 = bst2.predict(dtest2)
# assert they are the same # assert they are the same
assert np.sum(np.abs(preds2-preds)) == 0 assert np.sum(np.abs(preds2 - preds)) == 0
def test_dmatrix_init(self): def test_dmatrix_init(self):
data = np.random.randn(5, 5) data = np.random.randn(5, 5)
@ -62,6 +63,7 @@ class TestBasic(unittest.TestCase):
def incorrect_type_set(): def incorrect_type_set():
dm.feature_types = list('abcde') dm.feature_types = list('abcde')
self.assertRaises(ValueError, incorrect_type_set) self.assertRaises(ValueError, incorrect_type_set)
# reset # reset
@ -83,10 +85,10 @@ class TestBasic(unittest.TestCase):
assert dm.num_row() == 100 assert dm.num_row() == 100
assert dm.num_col() == 5 assert dm.num_col() == 5
params={'objective': 'multi:softprob', params = {'objective': 'multi:softprob',
'eval_metric': 'mlogloss', 'eval_metric': 'mlogloss',
'eta': 0.3, 'eta': 0.3,
'num_class': 3} 'num_class': 3}
bst = xgb.train(params, dm, num_boost_round=10) bst = xgb.train(params, dm, num_boost_round=10)
scores = bst.get_fscore() scores = bst.get_fscore()
@ -143,9 +145,9 @@ class TestBasic(unittest.TestCase):
# 1 2 0 1 0 # 1 2 0 1 0
# 2 3 0 0 1 # 2 3 0 0 1
result, _, _ = xgb.core._maybe_pandas_data(dummies, None, None) result, _, _ = xgb.core._maybe_pandas_data(dummies, None, None)
exp = np.array([[ 1., 1., 0., 0.], exp = np.array([[1., 1., 0., 0.],
[ 2., 0., 1., 0.], [2., 0., 1., 0.],
[ 3., 0., 0., 1.]]) [3., 0., 0., 1.]])
np.testing.assert_array_equal(result, exp) np.testing.assert_array_equal(result, exp)
dm = xgb.DMatrix(dummies) dm = xgb.DMatrix(dummies)
@ -180,7 +182,6 @@ class TestBasic(unittest.TestCase):
assert dm.num_row() == 3 assert dm.num_row() == 3
assert dm.num_col() == 2 assert dm.num_col() == 2
def test_load_file_invalid(self): def test_load_file_invalid(self):
self.assertRaises(ValueError, xgb.Booster, self.assertRaises(ValueError, xgb.Booster,
@ -213,7 +214,7 @@ class TestBasic(unittest.TestCase):
def test_cv(self): def test_cv(self):
dm = xgb.DMatrix(dpath + 'agaricus.txt.train') dm = xgb.DMatrix(dpath + 'agaricus.txt.train')
params = {'max_depth':2, 'eta':1, 'silent':1, 'objective':'binary:logistic' } params = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
import pandas as pd import pandas as pd
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10) cv = xgb.cv(params, dm, num_boost_round=10, nfold=10)
@ -241,6 +242,47 @@ class TestBasic(unittest.TestCase):
assert isinstance(cv, np.ndarray) assert isinstance(cv, np.ndarray)
assert cv.shape == (10, 4) assert cv.shape == (10, 4)
params = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic', 'eval_metric': 'auc'}
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10, as_pandas=True)
assert 'eval_metric' in params
assert 'auc' in cv.columns[0]
params = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic', 'eval_metric': ['auc']}
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10, as_pandas=True)
assert 'eval_metric' in params
assert 'auc' in cv.columns[0]
params = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic', 'eval_metric': ['auc']}
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10, as_pandas=True, early_stopping_rounds=1)
assert 'eval_metric' in params
assert 'auc' in cv.columns[0]
assert cv.shape[0] < 10
params = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10, as_pandas=True, metrics='auc')
assert 'auc' in cv.columns[0]
params = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10, as_pandas=True, metrics=['auc'])
assert 'auc' in cv.columns[0]
params = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic', 'eval_metric': ['auc']}
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10, as_pandas=True, metrics='error')
assert 'eval_metric' in params
assert 'auc' not in cv.columns[0]
assert 'error' in cv.columns[0]
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10, as_pandas=True, metrics=['error'])
assert 'eval_metric' in params
assert 'auc' not in cv.columns[0]
assert 'error' in cv.columns[0]
params = list(params.items())
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10, as_pandas=True, metrics=['error'])
assert isinstance(params, list)
assert 'auc' not in cv.columns[0]
assert 'error' in cv.columns[0]
def test_plotting(self): def test_plotting(self):
bst2 = xgb.Booster(model_file='xgb.model') bst2 = xgb.Booster(model_file='xgb.model')
# plotting # plotting
@ -263,7 +305,7 @@ class TestBasic(unittest.TestCase):
assert ax.get_ylabel() == 'y' assert ax.get_ylabel() == 'y'
assert len(ax.patches) == 4 assert len(ax.patches) == 4
for p in ax.patches: for p in ax.patches:
assert p.get_facecolor() == (1.0, 0, 0, 1.0) # red assert p.get_facecolor() == (1.0, 0, 0, 1.0) # red
ax = xgb.plot_importance(bst2, color=['r', 'r', 'b', 'b'], ax = xgb.plot_importance(bst2, color=['r', 'r', 'b', 'b'],
title=None, xlabel=None, ylabel=None) title=None, xlabel=None, ylabel=None)
@ -272,10 +314,10 @@ class TestBasic(unittest.TestCase):
assert ax.get_xlabel() == '' assert ax.get_xlabel() == ''
assert ax.get_ylabel() == '' assert ax.get_ylabel() == ''
assert len(ax.patches) == 4 assert len(ax.patches) == 4
assert ax.patches[0].get_facecolor() == (1.0, 0, 0, 1.0) # red assert ax.patches[0].get_facecolor() == (1.0, 0, 0, 1.0) # red
assert ax.patches[1].get_facecolor() == (1.0, 0, 0, 1.0) # red assert ax.patches[1].get_facecolor() == (1.0, 0, 0, 1.0) # red
assert ax.patches[2].get_facecolor() == (0, 0, 1.0, 1.0) # blue assert ax.patches[2].get_facecolor() == (0, 0, 1.0, 1.0) # blue
assert ax.patches[3].get_facecolor() == (0, 0, 1.0, 1.0) # blue assert ax.patches[3].get_facecolor() == (0, 0, 1.0, 1.0) # blue
g = xgb.to_graphviz(bst2, num_trees=0) g = xgb.to_graphviz(bst2, num_trees=0)
assert isinstance(g, Digraph) assert isinstance(g, Digraph)
@ -285,7 +327,7 @@ class TestBasic(unittest.TestCase):
def test_importance_plot_lim(self): def test_importance_plot_lim(self):
np.random.seed(1) np.random.seed(1)
dm = xgb.DMatrix(np.random.randn(100, 100), label=[0, 1]*50) dm = xgb.DMatrix(np.random.randn(100, 100), label=[0, 1] * 50)
bst = xgb.train({}, dm) bst = xgb.train({}, dm)
assert len(bst.get_fscore()) == 71 assert len(bst.get_fscore()) == 71
ax = xgb.plot_importance(bst) ax = xgb.plot_importance(bst)