Update Python custom objective demo. (#5981)
This commit is contained in:
@@ -197,9 +197,9 @@ class TestModels(unittest.TestCase):
|
||||
assert np.all(np.abs(predt_2 - predt_1) < 1e-6)
|
||||
|
||||
def test_custom_objective(self):
|
||||
param = {'max_depth': 2, 'eta': 1, 'verbosity': 0}
|
||||
param = {'max_depth': 2, 'eta': 1, 'objective': 'reg:logistic'}
|
||||
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
|
||||
num_round = 2
|
||||
num_round = 10
|
||||
|
||||
def logregobj(preds, dtrain):
|
||||
labels = dtrain.get_label()
|
||||
@@ -210,10 +210,12 @@ class TestModels(unittest.TestCase):
|
||||
|
||||
def evalerror(preds, dtrain):
|
||||
labels = dtrain.get_label()
|
||||
preds = 1.0 / (1.0 + np.exp(-preds))
|
||||
return 'error', float(sum(labels != (preds > 0.5))) / len(labels)
|
||||
|
||||
# test custom_objective in training
|
||||
bst = xgb.train(param, dtrain, num_round, watchlist, logregobj, evalerror)
|
||||
bst = xgb.train(param, dtrain, num_round, watchlist, obj=logregobj,
|
||||
feval=evalerror)
|
||||
assert isinstance(bst, xgb.core.Booster)
|
||||
preds = bst.predict(dtest)
|
||||
labels = dtest.get_label()
|
||||
@@ -230,7 +232,8 @@ class TestModels(unittest.TestCase):
|
||||
labels = dtrain.get_label()
|
||||
return 'error', float(sum(labels == (preds > 0.0))) / len(labels)
|
||||
|
||||
bst2 = xgb.train(param, dtrain, num_round, watchlist, logregobj, neg_evalerror, maximize=True)
|
||||
bst2 = xgb.train(param, dtrain, num_round, watchlist, logregobj,
|
||||
neg_evalerror, maximize=True)
|
||||
preds2 = bst2.predict(dtest)
|
||||
err2 = sum(1 for i in range(len(preds2))
|
||||
if int(preds2[i] > 0.5) != labels[i]) / float(len(preds2))
|
||||
|
||||
Reference in New Issue
Block a user