diff --git a/tests/python/test_models.py b/tests/python/test_models.py index 8c06d9de9..2308b1229 100644 --- a/tests/python/test_models.py +++ b/tests/python/test_models.py @@ -36,4 +36,13 @@ def test_custom_objective(): err = sum(1 for i in range(len(preds)) if int(preds[i]>0.5)!=labels[i]) / float(len(preds)) assert err < 0.1 - +def test_fpreproc(): + param = {'max_depth':2, 'eta':1, 'silent':1, 'objective':'binary:logistic'} + num_round = 2 + def fpreproc(dtrain, dtest, param): + label = dtrain.get_label() + ratio = float(np.sum(label == 0)) / np.sum(label==1) + param['scale_pos_weight'] = ratio + return (dtrain, dtest, param) + xgb.cv(param, dtrain, num_round, nfold=5, + metrics={'auc'}, seed = 0, fpreproc = fpreproc)