Fix learning rate scheduler with cv. (#6720)

* Expose more methods in cvpack and packed booster.
* Fix cv context in deprecated callbacks.
* Fix document.
This commit is contained in:
Jiaming Yuan
2021-02-28 13:57:42 +08:00
committed by GitHub
parent 9c8523432a
commit a9b4a95225
3 changed files with 46 additions and 20 deletions

View File

@@ -206,6 +206,7 @@ class TestCallbacks:
booster.best_iteration + early_stopping_rounds + 1
def run_eta_decay(self, tree_method, deprecated_callback):
"""Test learning rate scheduler, used by both CPU and GPU tests."""
if deprecated_callback:
scheduler = xgb.callback.reset_learning_rate
else:
@@ -217,7 +218,10 @@ class TestCallbacks:
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
num_round = 4
warning_check = pytest.warns(UserWarning) if deprecated_callback else tm.noop_context()
if deprecated_callback:
warning_check = pytest.warns(UserWarning)
else:
warning_check = tm.noop_context()
# learning_rates as a list
# init eta with 0 to check whether learning_rates work
@@ -288,17 +292,22 @@ class TestCallbacks:
for i in range(1, len(eval_errors_0)):
assert eval_errors_3[i] != eval_errors_2[i]
def test_eta_decay_hist(self):
self.run_eta_decay('hist', True)
self.run_eta_decay('hist', False)
with warning_check:
xgb.cv(param, dtrain, num_round, callbacks=[scheduler(eta_decay)])
def test_eta_decay_approx(self):
self.run_eta_decay('approx', True)
self.run_eta_decay('approx', False)
def test_eta_decay_exact(self):
self.run_eta_decay('exact', True)
self.run_eta_decay('exact', False)
@pytest.mark.parametrize(
"tree_method, deprecated_callback",
[
("hist", True),
("hist", False),
("approx", True),
("approx", False),
("exact", True),
("exact", False),
],
)
def test_eta_decay(self, tree_method, deprecated_callback):
self.run_eta_decay(tree_method, deprecated_callback)
def test_check_point(self):
from sklearn.datasets import load_breast_cancer