add cv for python

This commit is contained in:
tqchen
2014-09-03 22:43:55 -07:00
parent 586d6ae740
commit da9c856701
6 changed files with 91 additions and 10 deletions

View File

@@ -448,11 +448,13 @@ def mknfold(dall, nfold, param, seed, evals=[], fpreproc = None):
# run preprocessing on the data set if needed
if fpreproc is not None:
dtrain, dtest, tparam = fpreproc(dtrain, dtest, param.copy())
else:
tparam = param
plst = tparam.items() + [('eval_metric', itm) for itm in evals]
ret.append(CVPack(dtrain, dtest, plst))
return ret
def aggcv(rlist):
def aggcv(rlist, show_stdv=True):
"""
aggregate cross validation results
"""
@@ -468,11 +470,14 @@ def aggcv(rlist):
cvmap[k].append(float(v))
for k, v in sorted(cvmap.items(), key = lambda x:x[0]):
v = np.array(v)
ret += '\t%s:%f+%f' % (k, np.mean(v), np.std(v))
if show_stdv:
ret += '\tcv-%s:%f+%f' % (k, np.mean(v), np.std(v))
else:
ret += '\tcv-%s:%f' % (k, np.mean(v))
return ret
def cv(params, dtrain, num_boost_round = 10, nfold=3, eval_metric = [], \
obj = None, feval = None, fpreproc = None):
def cv(params, dtrain, num_boost_round = 10, nfold=3, metrics=[], \
obj = None, feval = None, fpreproc = None, show_stdv = True, seed = 0):
""" cross validation with given paramaters
Args:
params: dict
@@ -485,14 +490,21 @@ def cv(params, dtrain, num_boost_round = 10, nfold=3, eval_metric = [], \
folds to do cv
evals: list or
list of items to be evaluated
obj:
feval:
obj: custom objective function
feval: custom evaluation function
fpreproc: preprocessing function that takes dtrain, dtest,
param and return transformed version of dtrain, dtest, param
show_stdv: whether display standard deviation
seed: seed used to generate the folds
Returns: list(string) of evaluation history
"""
cvfolds = mknfold(dtrain, nfold, params, 0, eval_metric, fpreproc)
results = []
cvfolds = mknfold(dtrain, nfold, params, seed, metrics, fpreproc)
for i in range(num_boost_round):
for f in cvfolds:
f.update(i, obj)
res = aggcv([f.eval(i, feval) for f in cvfolds])
res = aggcv([f.eval(i, feval) for f in cvfolds], show_stdv)
sys.stderr.write(res+'\n')
results.append(res)
return results