diff --git a/tests/python/test_models.py b/tests/python/test_models.py index 9fc4d7472..6842a67b6 100644 --- a/tests/python/test_models.py +++ b/tests/python/test_models.py @@ -29,6 +29,8 @@ def test_custom_objective(): def evalerror(preds, dtrain): labels = dtrain.get_label() return 'error', float(sum(labels != (preds > 0.0))) / len(labels) + + # test custom_objective in training bst = xgb.train(param, dtrain, num_round, watchlist, logregobj, evalerror) assert isinstance(bst, xgb.core.Booster) preds = bst.predict(dtest) @@ -36,6 +38,10 @@ def test_custom_objective(): err = sum(1 for i in range(len(preds)) if int(preds[i]>0.5)!=labels[i]) / float(len(preds)) assert err < 0.1 + # test custom_objective in cross-validation + xgb.cv(param, dtrain, num_round, nfold = 5, seed = 0, + obj = logregobj, feval=evalerror) + def test_fpreproc(): param = {'max_depth':2, 'eta':1, 'silent':1, 'objective':'binary:logistic'} num_round = 2 @@ -53,7 +59,7 @@ def test_show_stdv(): xgb.cv(param, dtrain, num_round, nfold=5, metrics={'error'}, seed = 0, show_stdv = False) - +test_custom_objective()