CV returns ndarray or DataFrame
This commit is contained in:
parent
db490d1c75
commit
b958c55ac6
@ -1,5 +1,6 @@
|
|||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
# pylint: disable=too-many-locals, too-many-arguments, invalid-name
|
# pylint: disable=too-many-locals, too-many-arguments, invalid-name
|
||||||
|
# pylint: disable=too-many-branches
|
||||||
"""Training Library containing training routines."""
|
"""Training Library containing training routines."""
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
|
|
||||||
@ -179,16 +180,16 @@ def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None):
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def aggcv(rlist, show_stdv=True):
|
def aggcv(rlist, show_stdv=True, show_progress=None, as_pandas=True):
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
"""
|
"""
|
||||||
Aggregate cross-validation results.
|
Aggregate cross-validation results.
|
||||||
"""
|
"""
|
||||||
cvmap = {}
|
cvmap = {}
|
||||||
ret = rlist[0].split()[0]
|
idx = rlist[0].split()[0]
|
||||||
for line in rlist:
|
for line in rlist:
|
||||||
arr = line.split()
|
arr = line.split()
|
||||||
assert ret == arr[0]
|
assert idx == arr[0]
|
||||||
for it in arr[1:]:
|
for it in arr[1:]:
|
||||||
if not isinstance(it, STRING_TYPES):
|
if not isinstance(it, STRING_TYPES):
|
||||||
it = it.decode()
|
it = it.decode()
|
||||||
@ -196,19 +197,50 @@ def aggcv(rlist, show_stdv=True):
|
|||||||
if k not in cvmap:
|
if k not in cvmap:
|
||||||
cvmap[k] = []
|
cvmap[k] = []
|
||||||
cvmap[k].append(float(v))
|
cvmap[k].append(float(v))
|
||||||
|
|
||||||
|
msg = idx
|
||||||
|
|
||||||
|
if show_stdv:
|
||||||
|
fmt = '\tcv-{0}:{1}+{2}'
|
||||||
|
else:
|
||||||
|
fmt = '\tcv-{0}:{1}'
|
||||||
|
|
||||||
|
index = []
|
||||||
|
results = []
|
||||||
for k, v in sorted(cvmap.items(), key=lambda x: x[0]):
|
for k, v in sorted(cvmap.items(), key=lambda x: x[0]):
|
||||||
v = np.array(v)
|
v = np.array(v)
|
||||||
if not isinstance(ret, STRING_TYPES):
|
if not isinstance(msg, STRING_TYPES):
|
||||||
ret = ret.decode()
|
msg = msg.decode()
|
||||||
if show_stdv:
|
mean, std = np.mean(v), np.std(v)
|
||||||
ret += '\tcv-%s:%f+%f' % (k, np.mean(v), np.std(v))
|
msg += fmt.format(k, mean, std)
|
||||||
else:
|
|
||||||
ret += '\tcv-%s:%f' % (k, np.mean(v))
|
index.extend([k + '-mean', k + '-std'])
|
||||||
return ret
|
results.extend([mean, std])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if as_pandas:
|
||||||
|
try:
|
||||||
|
import pandas as pd
|
||||||
|
results = pd.Series(results, index=index)
|
||||||
|
except ImportError:
|
||||||
|
if show_progress is None:
|
||||||
|
show_progress = True
|
||||||
|
else:
|
||||||
|
# if show_progress is default (None),
|
||||||
|
# result will be np.ndarray as it can't hold column name
|
||||||
|
if show_progress is None:
|
||||||
|
show_progress = True
|
||||||
|
|
||||||
|
if show_progress:
|
||||||
|
sys.stderr.write(msg + '\n')
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
|
def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
|
||||||
obj=None, feval=None, fpreproc=None, show_stdv=True, seed=0):
|
obj=None, feval=None, fpreproc=None, as_pandas=True,
|
||||||
|
show_progress=None, show_stdv=True, seed=0):
|
||||||
# pylint: disable = invalid-name
|
# pylint: disable = invalid-name
|
||||||
"""Cross-validation with given paramaters.
|
"""Cross-validation with given paramaters.
|
||||||
|
|
||||||
@ -231,8 +263,15 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
|
|||||||
fpreproc : function
|
fpreproc : function
|
||||||
Preprocessing function that takes (dtrain, dtest, param) and returns
|
Preprocessing function that takes (dtrain, dtest, param) and returns
|
||||||
transformed versions of those.
|
transformed versions of those.
|
||||||
show_stdv : bool
|
as_pandas : bool, default True
|
||||||
Whether to display the standard deviation.
|
Return pd.DataFrame when pandas is installed.
|
||||||
|
If False or pandas is not installed, return np.ndarray
|
||||||
|
show_progress : bool or None, default None
|
||||||
|
Whether to display the progress. If None, progress will be displayed
|
||||||
|
when np.ndarray is returned.
|
||||||
|
show_stdv : bool, default True
|
||||||
|
Whether to display the standard deviation in progress.
|
||||||
|
Results are not affected, and always contains std.
|
||||||
seed : int
|
seed : int
|
||||||
Seed used to generate the folds (passed to numpy.random.seed).
|
Seed used to generate the folds (passed to numpy.random.seed).
|
||||||
|
|
||||||
@ -245,8 +284,19 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
|
|||||||
for i in range(num_boost_round):
|
for i in range(num_boost_round):
|
||||||
for fold in cvfolds:
|
for fold in cvfolds:
|
||||||
fold.update(i, obj)
|
fold.update(i, obj)
|
||||||
res = aggcv([f.eval(i, feval) for f in cvfolds], show_stdv)
|
res = aggcv([f.eval(i, feval) for f in cvfolds],
|
||||||
sys.stderr.write(res + '\n')
|
show_stdv=show_stdv, show_progress=show_progress,
|
||||||
|
as_pandas=as_pandas)
|
||||||
results.append(res)
|
results.append(res)
|
||||||
|
|
||||||
|
if as_pandas:
|
||||||
|
try:
|
||||||
|
import pandas as pd
|
||||||
|
results = pd.DataFrame(results)
|
||||||
|
except ImportError:
|
||||||
|
results = np.array(results)
|
||||||
|
else:
|
||||||
|
results = np.array(results)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|||||||
@ -64,7 +64,7 @@ if [ ${TASK} == "python-package" -o ${TASK} == "python-package3" ]; then
|
|||||||
conda create -n myenv python=2.7
|
conda create -n myenv python=2.7
|
||||||
fi
|
fi
|
||||||
source activate myenv
|
source activate myenv
|
||||||
conda install numpy scipy matplotlib nose
|
conda install numpy scipy pandas matplotlib nose
|
||||||
python -m pip install graphviz
|
python -m pip install graphviz
|
||||||
|
|
||||||
make all CXX=${CXX} || exit -1
|
make all CXX=${CXX} || exit -1
|
||||||
|
|||||||
@ -127,6 +127,36 @@ class TestBasic(unittest.TestCase):
|
|||||||
data = np.array([['a', 'b'], ['c', 'd']])
|
data = np.array([['a', 'b'], ['c', 'd']])
|
||||||
self.assertRaises(ValueError, xgb.DMatrix, data)
|
self.assertRaises(ValueError, xgb.DMatrix, data)
|
||||||
|
|
||||||
|
def test_cv(self):
|
||||||
|
dm = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||||
|
params = {'max_depth':2, 'eta':1, 'silent':1, 'objective':'binary:logistic' }
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10)
|
||||||
|
assert isinstance(cv, pd.DataFrame)
|
||||||
|
exp = pd.Index([u'test-error-mean', u'test-error-std',
|
||||||
|
u'train-error-mean', u'train-error-std'])
|
||||||
|
assert cv.columns.equals(exp)
|
||||||
|
|
||||||
|
# show progress log (result is the same as above)
|
||||||
|
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10,
|
||||||
|
show_progress=True)
|
||||||
|
assert isinstance(cv, pd.DataFrame)
|
||||||
|
exp = pd.Index([u'test-error-mean', u'test-error-std',
|
||||||
|
u'train-error-mean', u'train-error-std'])
|
||||||
|
assert cv.columns.equals(exp)
|
||||||
|
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10,
|
||||||
|
show_progress=True, show_stdv=False)
|
||||||
|
assert isinstance(cv, pd.DataFrame)
|
||||||
|
exp = pd.Index([u'test-error-mean', u'test-error-std',
|
||||||
|
u'train-error-mean', u'train-error-std'])
|
||||||
|
assert cv.columns.equals(exp)
|
||||||
|
|
||||||
|
# return np.ndarray
|
||||||
|
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10, as_pandas=False)
|
||||||
|
assert isinstance(cv, np.ndarray)
|
||||||
|
assert cv.shape == (10, 4)
|
||||||
|
|
||||||
def test_plotting(self):
|
def test_plotting(self):
|
||||||
bst2 = xgb.Booster(model_file='xgb.model')
|
bst2 = xgb.Booster(model_file='xgb.model')
|
||||||
# plotting
|
# plotting
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user