CV returns ndarray or DataFrame
This commit is contained in:
parent
db490d1c75
commit
b958c55ac6
@ -1,5 +1,6 @@
|
||||
# coding: utf-8
|
||||
# pylint: disable=too-many-locals, too-many-arguments, invalid-name
|
||||
# pylint: disable=too-many-branches
|
||||
"""Training Library containing training routines."""
|
||||
from __future__ import absolute_import
|
||||
|
||||
@ -179,16 +180,16 @@ def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None):
|
||||
return ret
|
||||
|
||||
|
||||
def aggcv(rlist, show_stdv=True):
|
||||
def aggcv(rlist, show_stdv=True, show_progress=None, as_pandas=True):
|
||||
# pylint: disable=invalid-name
|
||||
"""
|
||||
Aggregate cross-validation results.
|
||||
"""
|
||||
cvmap = {}
|
||||
ret = rlist[0].split()[0]
|
||||
idx = rlist[0].split()[0]
|
||||
for line in rlist:
|
||||
arr = line.split()
|
||||
assert ret == arr[0]
|
||||
assert idx == arr[0]
|
||||
for it in arr[1:]:
|
||||
if not isinstance(it, STRING_TYPES):
|
||||
it = it.decode()
|
||||
@ -196,19 +197,50 @@ def aggcv(rlist, show_stdv=True):
|
||||
if k not in cvmap:
|
||||
cvmap[k] = []
|
||||
cvmap[k].append(float(v))
|
||||
|
||||
msg = idx
|
||||
|
||||
if show_stdv:
|
||||
fmt = '\tcv-{0}:{1}+{2}'
|
||||
else:
|
||||
fmt = '\tcv-{0}:{1}'
|
||||
|
||||
index = []
|
||||
results = []
|
||||
for k, v in sorted(cvmap.items(), key=lambda x: x[0]):
|
||||
v = np.array(v)
|
||||
if not isinstance(ret, STRING_TYPES):
|
||||
ret = ret.decode()
|
||||
if show_stdv:
|
||||
ret += '\tcv-%s:%f+%f' % (k, np.mean(v), np.std(v))
|
||||
if not isinstance(msg, STRING_TYPES):
|
||||
msg = msg.decode()
|
||||
mean, std = np.mean(v), np.std(v)
|
||||
msg += fmt.format(k, mean, std)
|
||||
|
||||
index.extend([k + '-mean', k + '-std'])
|
||||
results.extend([mean, std])
|
||||
|
||||
|
||||
|
||||
if as_pandas:
|
||||
try:
|
||||
import pandas as pd
|
||||
results = pd.Series(results, index=index)
|
||||
except ImportError:
|
||||
if show_progress is None:
|
||||
show_progress = True
|
||||
else:
|
||||
ret += '\tcv-%s:%f' % (k, np.mean(v))
|
||||
return ret
|
||||
# if show_progress is default (None),
|
||||
# result will be np.ndarray as it can't hold column name
|
||||
if show_progress is None:
|
||||
show_progress = True
|
||||
|
||||
if show_progress:
|
||||
sys.stderr.write(msg + '\n')
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
|
||||
obj=None, feval=None, fpreproc=None, show_stdv=True, seed=0):
|
||||
obj=None, feval=None, fpreproc=None, as_pandas=True,
|
||||
show_progress=None, show_stdv=True, seed=0):
|
||||
# pylint: disable = invalid-name
|
||||
"""Cross-validation with given paramaters.
|
||||
|
||||
@ -231,8 +263,15 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
|
||||
fpreproc : function
|
||||
Preprocessing function that takes (dtrain, dtest, param) and returns
|
||||
transformed versions of those.
|
||||
show_stdv : bool
|
||||
Whether to display the standard deviation.
|
||||
as_pandas : bool, default True
|
||||
Return pd.DataFrame when pandas is installed.
|
||||
If False or pandas is not installed, return np.ndarray
|
||||
show_progress : bool or None, default None
|
||||
Whether to display the progress. If None, progress will be displayed
|
||||
when np.ndarray is returned.
|
||||
show_stdv : bool, default True
|
||||
Whether to display the standard deviation in progress.
|
||||
Results are not affected, and always contains std.
|
||||
seed : int
|
||||
Seed used to generate the folds (passed to numpy.random.seed).
|
||||
|
||||
@ -245,8 +284,19 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
|
||||
for i in range(num_boost_round):
|
||||
for fold in cvfolds:
|
||||
fold.update(i, obj)
|
||||
res = aggcv([f.eval(i, feval) for f in cvfolds], show_stdv)
|
||||
sys.stderr.write(res + '\n')
|
||||
res = aggcv([f.eval(i, feval) for f in cvfolds],
|
||||
show_stdv=show_stdv, show_progress=show_progress,
|
||||
as_pandas=as_pandas)
|
||||
results.append(res)
|
||||
|
||||
if as_pandas:
|
||||
try:
|
||||
import pandas as pd
|
||||
results = pd.DataFrame(results)
|
||||
except ImportError:
|
||||
results = np.array(results)
|
||||
else:
|
||||
results = np.array(results)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@ -64,7 +64,7 @@ if [ ${TASK} == "python-package" -o ${TASK} == "python-package3" ]; then
|
||||
conda create -n myenv python=2.7
|
||||
fi
|
||||
source activate myenv
|
||||
conda install numpy scipy matplotlib nose
|
||||
conda install numpy scipy pandas matplotlib nose
|
||||
python -m pip install graphviz
|
||||
|
||||
make all CXX=${CXX} || exit -1
|
||||
|
||||
@ -127,6 +127,36 @@ class TestBasic(unittest.TestCase):
|
||||
data = np.array([['a', 'b'], ['c', 'd']])
|
||||
self.assertRaises(ValueError, xgb.DMatrix, data)
|
||||
|
||||
def test_cv(self):
|
||||
dm = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||
params = {'max_depth':2, 'eta':1, 'silent':1, 'objective':'binary:logistic' }
|
||||
|
||||
import pandas as pd
|
||||
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10)
|
||||
assert isinstance(cv, pd.DataFrame)
|
||||
exp = pd.Index([u'test-error-mean', u'test-error-std',
|
||||
u'train-error-mean', u'train-error-std'])
|
||||
assert cv.columns.equals(exp)
|
||||
|
||||
# show progress log (result is the same as above)
|
||||
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10,
|
||||
show_progress=True)
|
||||
assert isinstance(cv, pd.DataFrame)
|
||||
exp = pd.Index([u'test-error-mean', u'test-error-std',
|
||||
u'train-error-mean', u'train-error-std'])
|
||||
assert cv.columns.equals(exp)
|
||||
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10,
|
||||
show_progress=True, show_stdv=False)
|
||||
assert isinstance(cv, pd.DataFrame)
|
||||
exp = pd.Index([u'test-error-mean', u'test-error-std',
|
||||
u'train-error-mean', u'train-error-std'])
|
||||
assert cv.columns.equals(exp)
|
||||
|
||||
# return np.ndarray
|
||||
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10, as_pandas=False)
|
||||
assert isinstance(cv, np.ndarray)
|
||||
assert cv.shape == (10, 4)
|
||||
|
||||
def test_plotting(self):
|
||||
bst2 = xgb.Booster(model_file='xgb.model')
|
||||
# plotting
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user