Define best_iteration only if early stopping is used. (#9403)
* Define `best_iteration` only if early stopping is used. This is the behavior specified by the document but not honored in the actual code. - Don't set the attributes if there's no early stopping. - Clean up the code for callbacks, and replace assertions with proper exceptions. - Assign the attributes when early stopping `save_best` is used. - Turn the attributes into Python properties. --------- Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -100,8 +100,8 @@ class TestTrainingContinuation:
|
||||
res2 = mean_squared_error(
|
||||
y_2class,
|
||||
gbdt_04.predict(
|
||||
dtrain_2class, iteration_range=(0, gbdt_04.best_iteration + 1)
|
||||
)
|
||||
dtrain_2class, iteration_range=(0, gbdt_04.num_boosted_rounds())
|
||||
),
|
||||
)
|
||||
assert res1 == res2
|
||||
|
||||
@@ -112,7 +112,7 @@ class TestTrainingContinuation:
|
||||
res2 = mean_squared_error(
|
||||
y_2class,
|
||||
gbdt_04.predict(
|
||||
dtrain_2class, iteration_range=(0, gbdt_04.best_iteration + 1)
|
||||
dtrain_2class, iteration_range=(0, gbdt_04.num_boosted_rounds())
|
||||
)
|
||||
)
|
||||
assert res1 == res2
|
||||
@@ -126,7 +126,7 @@ class TestTrainingContinuation:
|
||||
|
||||
res1 = gbdt_05.predict(dtrain_5class)
|
||||
res2 = gbdt_05.predict(
|
||||
dtrain_5class, iteration_range=(0, gbdt_05.best_iteration + 1)
|
||||
dtrain_5class, iteration_range=(0, gbdt_05.num_boosted_rounds())
|
||||
)
|
||||
np.testing.assert_almost_equal(res1, res2)
|
||||
|
||||
@@ -138,15 +138,16 @@ class TestTrainingContinuation:
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_training_continuation_updaters_json(self):
|
||||
# Picked up from R tests.
|
||||
updaters = 'grow_colmaker,prune,refresh'
|
||||
updaters = "grow_colmaker,prune,refresh"
|
||||
params = self.generate_parameters()
|
||||
for p in params:
|
||||
p['updater'] = updaters
|
||||
p["updater"] = updaters
|
||||
self.run_training_continuation(params[0], params[1], params[2])
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_changed_parameter(self):
|
||||
from sklearn.datasets import load_breast_cancer
|
||||
|
||||
X, y = load_breast_cancer(return_X_y=True)
|
||||
clf = xgb.XGBClassifier(n_estimators=2)
|
||||
clf.fit(X, y, eval_set=[(X, y)], eval_metric="logloss")
|
||||
|
||||
Reference in New Issue
Block a user