TST: Added test for custom_objective function in cv
This commit is contained in:
parent
dfb89e3442
commit
1411d3f37f
@ -29,6 +29,8 @@ def test_custom_objective():
|
|||||||
def evalerror(preds, dtrain):
|
def evalerror(preds, dtrain):
|
||||||
labels = dtrain.get_label()
|
labels = dtrain.get_label()
|
||||||
return 'error', float(sum(labels != (preds > 0.0))) / len(labels)
|
return 'error', float(sum(labels != (preds > 0.0))) / len(labels)
|
||||||
|
|
||||||
|
# test custom_objective in training
|
||||||
bst = xgb.train(param, dtrain, num_round, watchlist, logregobj, evalerror)
|
bst = xgb.train(param, dtrain, num_round, watchlist, logregobj, evalerror)
|
||||||
assert isinstance(bst, xgb.core.Booster)
|
assert isinstance(bst, xgb.core.Booster)
|
||||||
preds = bst.predict(dtest)
|
preds = bst.predict(dtest)
|
||||||
@ -36,6 +38,10 @@ def test_custom_objective():
|
|||||||
err = sum(1 for i in range(len(preds)) if int(preds[i]>0.5)!=labels[i]) / float(len(preds))
|
err = sum(1 for i in range(len(preds)) if int(preds[i]>0.5)!=labels[i]) / float(len(preds))
|
||||||
assert err < 0.1
|
assert err < 0.1
|
||||||
|
|
||||||
|
# test custom_objective in cross-validation
|
||||||
|
xgb.cv(param, dtrain, num_round, nfold = 5, seed = 0,
|
||||||
|
obj = logregobj, feval=evalerror)
|
||||||
|
|
||||||
def test_fpreproc():
|
def test_fpreproc():
|
||||||
param = {'max_depth':2, 'eta':1, 'silent':1, 'objective':'binary:logistic'}
|
param = {'max_depth':2, 'eta':1, 'silent':1, 'objective':'binary:logistic'}
|
||||||
num_round = 2
|
num_round = 2
|
||||||
@ -53,7 +59,7 @@ def test_show_stdv():
|
|||||||
xgb.cv(param, dtrain, num_round, nfold=5,
|
xgb.cv(param, dtrain, num_round, nfold=5,
|
||||||
metrics={'error'}, seed = 0, show_stdv = False)
|
metrics={'error'}, seed = 0, show_stdv = False)
|
||||||
|
|
||||||
|
test_custom_objective()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user