Rework Python callback functions. (#6199)
* Define a new callback interface for Python. * Deprecate the old callbacks. * Enable early stopping on dask.
This commit is contained in:
@@ -328,7 +328,7 @@ def test_sklearn_grid_search():
|
||||
reg.client = client
|
||||
model = GridSearchCV(reg, {'max_depth': [2, 4],
|
||||
'n_estimators': [5, 10]},
|
||||
cv=2, verbose=1, iid=True)
|
||||
cv=2, verbose=1)
|
||||
model.fit(X, y)
|
||||
# Expect unique results for each parameter value This confirms
|
||||
# sklearn is able to successfully update the parameter
|
||||
@@ -705,3 +705,42 @@ class TestWithDask:
|
||||
@pytest.mark.gtest
|
||||
def test_quantile_same_on_all_workers(self):
|
||||
self.run_quantile('SameOnAllWorkers')
|
||||
|
||||
|
||||
class TestDaskCallbacks:
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_early_stopping(self, client):
|
||||
from sklearn.datasets import load_breast_cancer
|
||||
X, y = load_breast_cancer(return_X_y=True)
|
||||
X, y = da.from_array(X), da.from_array(y)
|
||||
m = xgb.dask.DaskDMatrix(client, X, y)
|
||||
early_stopping_rounds = 5
|
||||
booster = xgb.dask.train(client, {'objective': 'binary:logistic',
|
||||
'eval_metric': 'error',
|
||||
'tree_method': 'hist'}, m,
|
||||
evals=[(m, 'Train')],
|
||||
num_boost_round=1000,
|
||||
early_stopping_rounds=early_stopping_rounds)['booster']
|
||||
assert hasattr(booster, 'best_score')
|
||||
assert booster.best_iteration == 10
|
||||
dump = booster.get_dump(dump_format='json')
|
||||
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_early_stopping_custom_eval(self, client):
|
||||
from sklearn.datasets import load_breast_cancer
|
||||
X, y = load_breast_cancer(return_X_y=True)
|
||||
X, y = da.from_array(X), da.from_array(y)
|
||||
m = xgb.dask.DaskDMatrix(client, X, y)
|
||||
early_stopping_rounds = 5
|
||||
booster = xgb.dask.train(
|
||||
client, {'objective': 'binary:logistic',
|
||||
'eval_metric': 'error',
|
||||
'tree_method': 'hist'}, m,
|
||||
evals=[(m, 'Train')],
|
||||
feval=tm.eval_error_metric,
|
||||
num_boost_round=1000,
|
||||
early_stopping_rounds=early_stopping_rounds)['booster']
|
||||
assert hasattr(booster, 'best_score')
|
||||
dump = booster.get_dump(dump_format='json')
|
||||
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
|
||||
|
||||
Reference in New Issue
Block a user