TST: Added test for custom_objective function in cv
This commit is contained in:
parent
dfb89e3442
commit
1411d3f37f
@ -29,6 +29,8 @@ def test_custom_objective():
|
||||
def evalerror(preds, dtrain):
|
||||
labels = dtrain.get_label()
|
||||
return 'error', float(sum(labels != (preds > 0.0))) / len(labels)
|
||||
|
||||
# test custom_objective in training
|
||||
bst = xgb.train(param, dtrain, num_round, watchlist, logregobj, evalerror)
|
||||
assert isinstance(bst, xgb.core.Booster)
|
||||
preds = bst.predict(dtest)
|
||||
@ -36,6 +38,10 @@ def test_custom_objective():
|
||||
err = sum(1 for i in range(len(preds)) if int(preds[i]>0.5)!=labels[i]) / float(len(preds))
|
||||
assert err < 0.1
|
||||
|
||||
# test custom_objective in cross-validation
|
||||
xgb.cv(param, dtrain, num_round, nfold = 5, seed = 0,
|
||||
obj = logregobj, feval=evalerror)
|
||||
|
||||
def test_fpreproc():
|
||||
param = {'max_depth':2, 'eta':1, 'silent':1, 'objective':'binary:logistic'}
|
||||
num_round = 2
|
||||
@ -53,7 +59,7 @@ def test_show_stdv():
|
||||
xgb.cv(param, dtrain, num_round, nfold=5,
|
||||
metrics={'error'}, seed = 0, show_stdv = False)
|
||||
|
||||
|
||||
test_custom_objective()
|
||||
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user