Remove old callback deprecated in 1.3. (#7280)
This commit is contained in:
@@ -22,15 +22,25 @@ def test_aft_survival_toy_data():
|
||||
# "Accuracy" = the number of data points whose ranged label (y_lower, y_upper) includes
|
||||
# the corresponding predicted label (y_pred)
|
||||
acc_rec = []
|
||||
def my_callback(env):
|
||||
y_pred = env.model.predict(dmat)
|
||||
acc = np.sum(np.logical_and(y_pred >= y_lower, y_pred <= y_upper)/len(X))
|
||||
acc_rec.append(acc)
|
||||
|
||||
class Callback(xgb.callback.TrainingCallback):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def after_iteration(
|
||||
self, model: xgb.Booster,
|
||||
epoch: int,
|
||||
evals_log: xgb.callback.TrainingCallback.EvalsLog
|
||||
):
|
||||
y_pred = model.predict(dmat)
|
||||
acc = np.sum(np.logical_and(y_pred >= y_lower, y_pred <= y_upper)/len(X))
|
||||
acc_rec.append(acc)
|
||||
return False
|
||||
|
||||
evals_result = {}
|
||||
params = {'max_depth': 3, 'objective':'survival:aft', 'min_child_weight': 0}
|
||||
params = {'max_depth': 3, 'objective': 'survival:aft', 'min_child_weight': 0}
|
||||
bst = xgb.train(params, dmat, 15, [(dmat, 'train')], evals_result=evals_result,
|
||||
callbacks=[my_callback])
|
||||
callbacks=[Callback()])
|
||||
|
||||
nloglik_rec = evals_result['train']['aft-nloglik']
|
||||
# AFT metric (negative log likelihood) improve monotonically
|
||||
|
||||
Reference in New Issue
Block a user