TST: Added test for custom_objective function in cv

This commit is contained in:
terrytangyuan 2015-10-04 22:45:10 -05:00
parent dfb89e3442
commit 1411d3f37f

View File

@ -29,6 +29,8 @@ def test_custom_objective():
def evalerror(preds, dtrain): def evalerror(preds, dtrain):
labels = dtrain.get_label() labels = dtrain.get_label()
return 'error', float(sum(labels != (preds > 0.0))) / len(labels) return 'error', float(sum(labels != (preds > 0.0))) / len(labels)
# test custom_objective in training
bst = xgb.train(param, dtrain, num_round, watchlist, logregobj, evalerror) bst = xgb.train(param, dtrain, num_round, watchlist, logregobj, evalerror)
assert isinstance(bst, xgb.core.Booster) assert isinstance(bst, xgb.core.Booster)
preds = bst.predict(dtest) preds = bst.predict(dtest)
@ -36,6 +38,10 @@ def test_custom_objective():
err = sum(1 for i in range(len(preds)) if int(preds[i]>0.5)!=labels[i]) / float(len(preds)) err = sum(1 for i in range(len(preds)) if int(preds[i]>0.5)!=labels[i]) / float(len(preds))
assert err < 0.1 assert err < 0.1
# test custom_objective in cross-validation
xgb.cv(param, dtrain, num_round, nfold = 5, seed = 0,
obj = logregobj, feval=evalerror)
def test_fpreproc(): def test_fpreproc():
param = {'max_depth':2, 'eta':1, 'silent':1, 'objective':'binary:logistic'} param = {'max_depth':2, 'eta':1, 'silent':1, 'objective':'binary:logistic'}
num_round = 2 num_round = 2
@ -53,7 +59,7 @@ def test_show_stdv():
xgb.cv(param, dtrain, num_round, nfold=5, xgb.cv(param, dtrain, num_round, nfold=5,
metrics={'error'}, seed = 0, show_stdv = False) metrics={'error'}, seed = 0, show_stdv = False)
test_custom_objective()