Fix parameter loading with training continuation. (#7121)

* Add a demo for training continuation.

parent 41e882f80b
commit 778135f657
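The change reorders Booster construction so that parameters passed in are applied after model_file is loaded rather than before, which means explicitly supplied parameters are no longer overridden during training continuation. Below is a minimal sketch of the workflow this fixes, using the native API; the checkpoint path and metric choice are illustrative, not taken from the commit.

    import os
    import tempfile

    import xgboost
    from sklearn.datasets import load_breast_cancer

    X, y = load_breast_cancer(return_X_y=True)
    dtrain = xgboost.DMatrix(X, label=y)

    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "checkpoint.json")
        # First session: 32 rounds, then save a checkpoint.
        booster = xgboost.train(
            {"objective": "binary:logistic"}, dtrain, num_boost_round=32
        )
        booster.save_model(path)

        # Second session: resume from the checkpoint for 96 more rounds. With
        # this fix, parameters passed here take effect despite the loaded model.
        continued = xgboost.train(
            {"objective": "binary:logistic", "eval_metric": "error"},
            dtrain,
            num_boost_round=96,
            xgb_model=path,
        )
        assert continued.num_boosted_rounds() == 128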
@@ -14,3 +14,5 @@ XGBoost Python Feature Walkthrough
 * [Sklearn access evals result](sklearn_evals_result.py)
 * [Access evals result](evals_result.py)
 * [External Memory](external_memory.py)
+* [Training continuation](continuation.py)
+* [Feature weights for column sampling](feature_weights.py)
demo/guide-python/continuation.py (new file, 109 lines)
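The demo below first trains for a full budget in a single session, then reproduces the same total in two sessions by checkpointing (via pickle or the JSON model format) and resuming with the xgb_model argument; a second function repeats the exercise with an early-stopping callback.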
@@ -0,0 +1,109 @@
"""
Demo for training continuation.
"""

import os
import pickle
import tempfile

import xgboost
from sklearn.datasets import load_breast_cancer


def training_continuation(tmpdir: str, use_pickle: bool) -> None:
    """Basic training continuation."""
    # Train 128 iterations in 1 session
    X, y = load_breast_cancer(return_X_y=True)
    clf = xgboost.XGBClassifier(n_estimators=128, use_label_encoder=False)
    clf.fit(X, y, eval_set=[(X, y)], eval_metric="logloss")
    print("Total boosted rounds:", clf.get_booster().num_boosted_rounds())

    # Train 128 iterations in 2 sessions: the first runs for 32 iterations and
    # the second for 96 iterations
    clf = xgboost.XGBClassifier(n_estimators=32, use_label_encoder=False)
    clf.fit(X, y, eval_set=[(X, y)], eval_metric="logloss")
    assert clf.get_booster().num_boosted_rounds() == 32

    # Load back the model; this could be a checkpoint
    if use_pickle:
        path = os.path.join(tmpdir, "model-first-32.pkl")
        with open(path, "wb") as fd:
            pickle.dump(clf, fd)
        with open(path, "rb") as fd:
            loaded = pickle.load(fd)
    else:
        path = os.path.join(tmpdir, "model-first-32.json")
        clf.save_model(path)
        loaded = xgboost.XGBClassifier()
        loaded.load_model(path)

    clf = xgboost.XGBClassifier(n_estimators=128 - 32, use_label_encoder=False)
    clf.fit(X, y, eval_set=[(X, y)], eval_metric="logloss", xgb_model=loaded)

    print("Total boosted rounds:", clf.get_booster().num_boosted_rounds())

    assert clf.get_booster().num_boosted_rounds() == 128


def training_continuation_early_stop(tmpdir: str, use_pickle: bool) -> None:
    """Training continuation with early stopping."""
    early_stopping_rounds = 5
    early_stop = xgboost.callback.EarlyStopping(
        rounds=early_stopping_rounds, save_best=True
    )
    n_estimators = 512

    X, y = load_breast_cancer(return_X_y=True)
    clf = xgboost.XGBClassifier(n_estimators=n_estimators, use_label_encoder=False)
    clf.fit(X, y, eval_set=[(X, y)], eval_metric="logloss", callbacks=[early_stop])
    print("Total boosted rounds:", clf.get_booster().num_boosted_rounds())
    best = clf.best_iteration

    # Train 512 iterations in 2 sessions: the first runs for 128 iterations and
    # the second runs until early stopping kicks in
    clf = xgboost.XGBClassifier(n_estimators=128, use_label_encoder=False)
    # Reinitialize the early-stopping callback
    early_stop = xgboost.callback.EarlyStopping(
        rounds=early_stopping_rounds, save_best=True
    )
    clf.fit(X, y, eval_set=[(X, y)], eval_metric="logloss", callbacks=[early_stop])
    assert clf.get_booster().num_boosted_rounds() == 128

    # Load back the model; this could be a checkpoint
    if use_pickle:
        path = os.path.join(tmpdir, "model-first-128.pkl")
        with open(path, "wb") as fd:
            pickle.dump(clf, fd)
        with open(path, "rb") as fd:
            loaded = pickle.load(fd)
    else:
        path = os.path.join(tmpdir, "model-first-128.json")
        clf.save_model(path)
        loaded = xgboost.XGBClassifier(use_label_encoder=False)
        loaded.load_model(path)

    early_stop = xgboost.callback.EarlyStopping(
        rounds=early_stopping_rounds, save_best=True
    )
    clf = xgboost.XGBClassifier(
        n_estimators=n_estimators - 128, use_label_encoder=False
    )
    clf.fit(
        X,
        y,
        eval_set=[(X, y)],
        eval_metric="logloss",
        callbacks=[early_stop],
        xgb_model=loaded,
    )

    print("Total boosted rounds:", clf.get_booster().num_boosted_rounds())
    assert clf.best_iteration == best


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmpdir:
        training_continuation_early_stop(tmpdir, False)
        training_continuation_early_stop(tmpdir, True)

        training_continuation(tmpdir, True)
        training_continuation(tmpdir, False)
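The script is self-contained: running "python continuation.py" executes both functions inside a temporary directory, which is exactly how the demo test further below invokes it.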
@@ -1298,19 +1298,6 @@ class Booster(object):
             # Validate feature only after the feature names are saved into booster.
             self._validate_features(d)
 
-        params = params or {}
-        params = self._configure_metrics(params.copy())
-        params = self._configure_constraints(params)
-        if isinstance(params, list):
-            params.append(('validate_parameters', True))
-        else:
-            params['validate_parameters'] = True
-
-        self.set_param(params or {})
-        if (params is not None) and ('booster' in params):
-            self.booster = params['booster']
-        else:
-            self.booster = 'gbtree'
         if isinstance(model_file, Booster):
             assert self.handle is not None
             # We use the pickle interface for getting memory snapshot from

@@ -1330,6 +1317,20 @@ class Booster(object):
         else:
             raise TypeError('Unknown type:', model_file)
 
+        params = params or {}
+        params = self._configure_metrics(params.copy())
+        params = self._configure_constraints(params)
+        if isinstance(params, list):
+            params.append(('validate_parameters', True))
+        else:
+            params['validate_parameters'] = True
+
+        self.set_param(params or {})
+        if (params is not None) and ('booster' in params):
+            self.booster = params['booster']
+        else:
+            self.booster = 'gbtree'
+
     def _configure_metrics(self, params: Union[Dict, List]) -> Union[Dict, List]:
         if isinstance(params, dict) and 'eval_metric' in params \
                 and isinstance(params['eval_metric'], list):
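The two hunks above move the parameter configuration from before the model_file load to after it. Previously set_param ran first, and the subsequent model load could restore saved configuration over it, so arguments such as eval_metric handed to the constructor were silently lost. A sketch of the user-visible effect, assuming a scikit-learn toy dataset; the checkpoint path is illustrative.

    import os
    import tempfile

    import xgboost
    from sklearn.datasets import load_breast_cancer

    X, y = load_breast_cancer(return_X_y=True)
    dtrain = xgboost.DMatrix(X, label=y)
    bst = xgboost.train({"eval_metric": "logloss"}, dtrain, num_boost_round=4)

    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "ckpt.json")
        bst.save_model(path)
        # Parameters supplied together with model_file are now applied after
        # the load, so this booster ends up configured with "error" rather
        # than whatever configuration the loaded model carries.
        bst2 = xgboost.Booster(params={"eval_metric": "error"}, model_file=path)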
@@ -127,6 +127,12 @@ def test_callbacks_demo():
     subprocess.check_call(cmd)
 
 
+def test_continuation_demo():
+    script = os.path.join(PYTHON_DEMO_DIR, 'continuation.py')
+    cmd = ['python', script]
+    subprocess.check_call(cmd)
+
+
 # gpu_acceleration is not tested because the covertype dataset is too large.
 # gamma regression is not tested as it requires running an R script first.
 # aft viz is not tested because plotting is not controlled.
@@ -3,6 +3,8 @@ import testing as tm
 import numpy as np
 import pytest
 import os
+import tempfile
 
 
 rng = np.random.RandomState(1337)

@@ -145,3 +147,21 @@ class TestTrainingContinuation:
         for p in params:
             p['updater'] = updaters
         self.run_training_continuation(params[0], params[1], params[2])
 
+    @pytest.mark.skipif(**tm.no_sklearn())
+    def test_changed_parameter(self):
+        from sklearn.datasets import load_breast_cancer
+        X, y = load_breast_cancer(return_X_y=True)
+        clf = xgb.XGBClassifier(n_estimators=2, use_label_encoder=False)
+        clf.fit(X, y, eval_set=[(X, y)], eval_metric="logloss")
+        assert tm.non_increasing(clf.evals_result()["validation_0"]["logloss"])
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            clf.save_model(os.path.join(tmpdir, "clf.json"))
+            loaded = xgb.XGBClassifier(use_label_encoder=False)
+            loaded.load_model(os.path.join(tmpdir, "clf.json"))
+
+            clf = xgb.XGBClassifier(n_estimators=2, use_label_encoder=False)
+            # change metric to error
+            clf.fit(X, y, eval_set=[(X, y)], eval_metric="error")
+            assert tm.non_increasing(clf.evals_result()["validation_0"]["error"])
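The added test saves and reloads a classifier trained with logloss, then fits again with the metric changed to error and asserts that evals_result() reports non-increasing values for the new metric.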