This commit is contained in:
antinucleon 2014-09-03 00:37:55 -06:00
parent 06b5533209
commit 02dd8d1212

View File

@ -430,10 +430,10 @@ class CVPack:
self.bst = Booster(param, [dtrain,dtest]) self.bst = Booster(param, [dtrain,dtest])
def update(self, r, fobj): def update(self, r, fobj):
self.bst.update(self.dtrain, r, fobj) self.bst.update(self.dtrain, r, fobj)
def eval(self, r, fval): def eval(self, r, feval):
return self.bst.eval_set(self.watchlist, r, feval) return self.bst.eval_set(self.watchlist, r, feval)
def mknfold(dall, nfold, param, seed, weightscale=None, evals=[]): def mknfold(dall, nfold, param, seed, evals=[]):
""" """
mk nfold list of cvpack from randidx mk nfold list of cvpack from randidx
""" """
@ -457,9 +457,6 @@ def mknfold(dall, nfold, param, seed, weightscale=None, evals=[]):
dtrain = dall.slice(trainlst) dtrain = dall.slice(trainlst)
dtest = dall.slice(testlst) dtest = dall.slice(testlst)
# rescale weight of dtrain and dtest # rescale weight of dtrain and dtest
if weightscale != None:
dtrain.set_weight( dtrain.get_weight() * weightscale * dall.num_row() / dtrain.num_row() )
dtest.set_weight( dtest.get_weight() * weightscale * dall.num_row() / dtest.num_row() )
plst = param.items() + [('eval_metric', itm) for itm in evals] plst = param.items() + [('eval_metric', itm) for itm in evals]
ret.append(CVPack(dtrain, dtest, plst)) ret.append(CVPack(dtrain, dtest, plst))
return ret return ret
@ -487,7 +484,7 @@ def aggcv(rlist):
return ret return ret
def cv(params, dtrain, num_boost_round = 10, nfold=3, eval_metrics = [], \ def cv(params, dtrain, num_boost_round = 10, nfold=3, eval_metrics = [], \
weightscale=None, obj=None, feval=None): obj=None, feval=None):
""" cross validation with given paramaters """ cross validation with given paramaters
Args: Args:
params: dict params: dict
@ -503,9 +500,9 @@ def cv(params, dtrain, num_boost_round = 10, nfold=3, eval_metrics = [], \
obj: obj:
feval: feval:
""" """
cvfolds = mknfold(dtrain, nfold, params, 0, weightscale, evals_metrics) cvfolds = mknfold(dtrain, nfold, params, 0, eval_metrics)
for i in range(num_boost_round): for i in range(num_boost_round):
for f in cvfolds: for f in cvfolds:
f.update(i, obj) f.update(i, obj)
res = aggcv([f.eval(i, fval) for f in cvfolds]) res = aggcv([f.eval(i, feval) for f in cvfolds])
sys.stderr.write(res+'\n') sys.stderr.write(res+'\n')