This commit is contained in:
antinucleon 2014-09-03 12:57:05 -06:00
parent 998ca3bdc9
commit 0c36231ea3

View File

@@ -438,17 +438,17 @@ def mknfold(dall, nfold, param, seed, evals=[], fpreproc = None):
mk nfold list of cvpack from randidx mk nfold list of cvpack from randidx
""" """
np.random.seed(seed) np.random.seed(seed)
randidx = np.random.permutation(dall.num_rows()) randidx = np.random.permutation(dall.num_row())
kstep = len(randidx) / nfold kstep = len(randidx) / nfold
idset = [randidx[ (i*kstep) : min(len(randidx),(i+1)*kstep) ] for i in range(nfold)] idset = [randidx[ (i*kstep) : min(len(randidx),(i+1)*kstep) ] for i in range(nfold)]
ret = [] ret = []
for k in range(nfold): for k in range(nfold):
dtrain = dall.slice(np.concatenate([idset[i] for i in range(nfold) if k != i])) dtrain = dall.slice(np.concatenate([idset[i] for i in range(nfold) if k != i]))
dtest = all.slice(idxset[k]) dtest = dall.slice(idset[k])
# run preprocessing on the data set if needed # run preprocessing on the data set if needed
if fpreproc is not None: if fpreproc is not None:
dtrain, dtest, tparam = fpreproc(dtrain, dtest, param.copy()) dtrain, dtest, tparam = fpreproc(dtrain, dtest, param.copy())
plst = tparam.items() + [('eval_metric', itm) for itm in evals] plst = param.items() + [('eval_metric', itm) for itm in evals]
ret.append(CVPack(dtrain, dtest, plst)) ret.append(CVPack(dtrain, dtest, plst))
return ret return ret
@@ -490,7 +490,7 @@ def cv(params, dtrain, num_boost_round = 10, nfold=3, eval_metric = [], \
fpreproc: preprocessing function that takes dtrain, dtest, fpreproc: preprocessing function that takes dtrain, dtest,
param and return transformed version of dtrain, dtest, param param and return transformed version of dtrain, dtest, param
""" """
cvfolds = mknfold(dtrain, nfold, params, 0, eval_metrics, fpreproc) cvfolds = mknfold(dtrain, nfold, params, 0, eval_metric, fpreproc)
for i in range(num_boost_round): for i in range(num_boost_round):
for f in cvfolds: for f in cvfolds:
f.update(i, obj) f.update(i, obj)