Remove old callback deprecated in 1.3. (#7280)

This commit is contained in:
Jiaming Yuan
2021-10-08 17:24:59 +08:00
committed by GitHub
parent 578de9f762
commit 69d3b1b8b4
7 changed files with 70 additions and 475 deletions

View File

@@ -22,15 +22,25 @@ def test_aft_survival_toy_data():
# "Accuracy" = the number of data points whose ranged label (y_lower, y_upper) includes
# the corresponding predicted label (y_pred)
acc_rec = []
def my_callback(env):
y_pred = env.model.predict(dmat)
acc = np.sum(np.logical_and(y_pred >= y_lower, y_pred <= y_upper)/len(X))
acc_rec.append(acc)
class Callback(xgb.callback.TrainingCallback):
def __init__(self):
super().__init__()
def after_iteration(
self, model: xgb.Booster,
epoch: int,
evals_log: xgb.callback.TrainingCallback.EvalsLog
):
y_pred = model.predict(dmat)
acc = np.sum(np.logical_and(y_pred >= y_lower, y_pred <= y_upper)/len(X))
acc_rec.append(acc)
return False
evals_result = {}
params = {'max_depth': 3, 'objective':'survival:aft', 'min_child_weight': 0}
params = {'max_depth': 3, 'objective': 'survival:aft', 'min_child_weight': 0}
bst = xgb.train(params, dmat, 15, [(dmat, 'train')], evals_result=evals_result,
callbacks=[my_callback])
callbacks=[Callback()])
nloglik_rec = evals_result['train']['aft-nloglik']
# AFT metric (negative log likelihood) improve monotonically