Use UBJ in Python checkpoint. (#9958)

This commit is contained in:
Jiaming Yuan 2024-01-09 03:22:15 +08:00 committed by GitHub
parent fa5e2f6c45
commit b3eb5d0945
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 104 additions and 46 deletions

View File

@ -7,6 +7,7 @@ Demo for using and defining callback functions
import argparse
import os
import tempfile
from typing import Dict
import numpy as np
from matplotlib import pyplot as plt
@ -17,24 +18,26 @@ import xgboost as xgb
class Plotting(xgb.callback.TrainingCallback):
"""Plot evaluation result during training. Only for demonstration purpose as it's quite
slow to draw.
"""Plot evaluation result during training. Only for demonstration purpose as it's
quite slow to draw using matplotlib.
"""
def __init__(self, rounds):
def __init__(self, rounds: int) -> None:
self.fig = plt.figure()
self.ax = self.fig.add_subplot(111)
self.rounds = rounds
self.lines = {}
self.lines: Dict[str, plt.Line2D] = {}
self.fig.show()
self.x = np.linspace(0, self.rounds, self.rounds)
plt.ion()
def _get_key(self, data, metric):
def _get_key(self, data: str, metric: str) -> str:
return f"{data}-{metric}"
def after_iteration(self, model, epoch, evals_log):
def after_iteration(
self, model: xgb.Booster, epoch: int, evals_log: Dict[str, dict]
) -> bool:
"""Update the plot."""
if not self.lines:
for data, metric in evals_log.items():
@ -55,7 +58,7 @@ class Plotting(xgb.callback.TrainingCallback):
return False
def custom_callback():
def custom_callback() -> None:
"""Demo for defining a custom callback function that plots evaluation result during
training."""
X, y = load_breast_cancer(return_X_y=True)
@ -82,19 +85,27 @@ def custom_callback():
)
def check_point_callback():
# only for demo, set a larger value (like 100) in practice as checkpointing is quite
def check_point_callback() -> None:
"""Demo for using the checkpoint callback. Custom logic for handling output is
usually required and users are encouraged to define their own callback for
checkpointing operations. The builtin one can be used as a starting point.
"""
# Only for demo, set a larger value (like 100) in practice as checkpointing is quite
# slow.
rounds = 2
def check(as_pickle):
def check(as_pickle: bool) -> None:
for i in range(0, 10, rounds):
if i == 0:
continue
if as_pickle:
path = os.path.join(tmpdir, "model_" + str(i) + ".pkl")
else:
path = os.path.join(tmpdir, "model_" + str(i) + ".json")
path = os.path.join(
tmpdir,
f"model_{i}.{xgb.callback.TrainingCheckPoint.default_format}",
)
assert os.path.exists(path)
X, y = load_breast_cancer(return_X_y=True)

View File

@ -88,22 +88,18 @@ Callback API
.. autoclass:: xgboost.callback.EvaluationMonitor
:members:
:inherited-members:
:show-inheritance:
.. autoclass:: xgboost.callback.EarlyStopping
:members:
:inherited-members:
:show-inheritance:
.. autoclass:: xgboost.callback.LearningRateScheduler
:members:
:inherited-members:
:show-inheritance:
.. autoclass:: xgboost.callback.TrainingCheckPoint
:members:
:inherited-members:
:show-inheritance:
.. _dask_api:

View File

@ -62,11 +62,31 @@ class TrainingCallback(ABC):
return model
def before_iteration(self, model: _Model, epoch: int, evals_log: EvalsLog) -> bool:
"""Run before each iteration. Return True when training should stop."""
"""Run before each iteration. Returns True when training should stop. See
:py:meth:`after_iteration` for details.
"""
return False
def after_iteration(self, model: _Model, epoch: int, evals_log: EvalsLog) -> bool:
"""Run after each iteration. Return True when training should stop."""
"""Run after each iteration. Returns `True` when training should stop.
Parameters
----------
model :
Either a :py:class:`~xgboost.Booster` object or a CVPack if the cv function
in xgboost is being used.
epoch :
The current training iteration.
evals_log :
A dictionary containing the evaluation history:
.. code-block:: python
{"data_name": {"metric_name": [0.5, ...]}}
"""
return False
@ -547,14 +567,16 @@ class TrainingCheckPoint(TrainingCallback):
.. versionadded:: 1.3.0
Since XGBoost 2.1.0, the default format is changed to UBJSON.
Parameters
----------
directory :
Output model directory.
name :
pattern of output model file. Models will be saved as name_0.json, name_1.json,
name_2.json ....
pattern of output model file. Models will be saved as name_0.ubj, name_1.ubj,
name_2.ubj ....
as_pickle :
When set to True, all training parameters will be saved in pickle format,
instead of saving only the model.
@ -564,6 +586,8 @@ class TrainingCheckPoint(TrainingCallback):
"""
default_format = "ubj"
def __init__(
self,
directory: Union[str, os.PathLike],
@ -592,7 +616,7 @@ class TrainingCheckPoint(TrainingCallback):
self._name
+ "_"
+ (str(epoch + self._start))
+ (".pkl" if self._as_pickle else ".json"),
+ (".pkl" if self._as_pickle else f".{self.default_format}"),
)
self._epoch = 0 # reset counter
if collective.get_rank() == 0:

View File

@ -2591,9 +2591,8 @@ class Booster:
The model is saved in an XGBoost internal format which is universal among the
various XGBoost interfaces. Auxiliary attributes of the Python Booster object
(such as feature_names) will not be saved when using binary format. To save
those attributes, use JSON/UBJ instead. See :doc:`Model IO
</tutorials/saving_model>` for more info.
(such as feature_names) are only saved when using JSON or UBJSON (default)
format. See :doc:`Model IO </tutorials/saving_model>` for more info.
.. code-block:: python
@ -2616,12 +2615,15 @@ class Booster:
def save_raw(self, raw_format: str = "ubj") -> bytearray:
"""Save the model to an in-memory buffer representation instead of a file.
The model is saved in an XGBoost internal format which is universal among the
various XGBoost interfaces. Auxiliary attributes of the Python Booster object
(such as feature_names) are only saved when using JSON or UBJSON (default)
format. See :doc:`Model IO </tutorials/saving_model>` for more info.
Parameters
----------
raw_format :
Format of output buffer. Can be `json`, `ubj` or `deprecated`. Right now
the default is `deprecated` but it will be changed to `ubj` (universal binary
json) in the future.
Format of output buffer. Can be `json`, `ubj` or `deprecated`.
Returns
-------
@ -2640,11 +2642,10 @@ class Booster:
def load_model(self, fname: ModelIn) -> None:
"""Load the model from a file or a bytearray.
The model is loaded from XGBoost format which is universal among the various
XGBoost interfaces. Auxiliary attributes of the Python Booster object (such as
feature_names) will not be loaded when using binary format. To save those
attributes, use JSON/UBJ instead. See :doc:`Model IO </tutorials/saving_model>`
for more info.
The model is loaded from an XGBoost internal format which is universal among the
various XGBoost interfaces. Auxiliary attributes of the Python Booster object
(such as feature_names) are only loaded when the model was saved in JSON or UBJSON
format. See :doc:`Model IO </tutorials/saving_model>` for more info.
.. code-block:: python
@ -2769,9 +2770,9 @@ class Booster:
with_stats: bool = False,
dump_format: str = "text",
) -> List[str]:
"""Returns the model dump as a list of strings. Unlike :py:meth:`save_model`, the output
format is primarily used for visualization or interpretation, hence it's more
human readable but cannot be loaded back to XGBoost.
"""Returns the model dump as a list of strings. Unlike :py:meth:`save_model`,
the output format is primarily used for visualization or interpretation, hence
it's more human readable but cannot be loaded back to XGBoost.
Parameters
----------

View File

@ -31,6 +31,8 @@ class LintersPaths:
"tests/python/test_with_pandas.py",
"tests/python-gpu/",
"tests/python-sycl/",
"tests/test_distributed/test_with_dask/",
"tests/test_distributed/test_gpu_with_dask/",
"tests/test_distributed/test_with_spark/",
"tests/test_distributed/test_gpu_with_spark/",
# demo
@ -91,6 +93,7 @@ class LintersPaths:
# demo
"demo/json-model/json_parser.py",
"demo/guide-python/external_memory.py",
"demo/guide-python/callbacks.py",
"demo/guide-python/cat_in_the_dat.py",
"demo/guide-python/categorical.py",
"demo/guide-python/cat_pipeline.py",

View File

@ -244,7 +244,7 @@ class TestCallbacks:
assert booster.num_boosted_rounds() == booster.best_iteration + 1
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, 'model.json')
path = os.path.join(tmpdir, "model.json")
cls.save_model(path)
cls = xgb.XGBClassifier()
cls.load_model(path)
@ -378,7 +378,7 @@ class TestCallbacks:
scheduler = xgb.callback.LearningRateScheduler
dtrain, dtest = tm.load_agaricus(__file__)
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
watchlist = [(dtest, "eval"), (dtrain, "train")]
param = {
"max_depth": 2,
@ -429,7 +429,7 @@ class TestCallbacks:
assert tree_3th_0["split_conditions"] != tree_3th_1["split_conditions"]
@pytest.mark.parametrize("tree_method", ["hist", "approx", "approx"])
def test_eta_decay(self, tree_method):
def test_eta_decay(self, tree_method: str) -> None:
self.run_eta_decay(tree_method)
@pytest.mark.parametrize(
@ -446,7 +446,7 @@ class TestCallbacks:
def test_eta_decay_leaf_output(self, tree_method: str, objective: str) -> None:
self.run_eta_decay_leaf_output(tree_method, objective)
def test_check_point(self):
def test_check_point(self) -> None:
from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True)
@ -463,7 +463,12 @@ class TestCallbacks:
callbacks=[check_point],
)
for i in range(1, 10):
assert os.path.exists(os.path.join(tmpdir, "model_" + str(i) + ".json"))
assert os.path.exists(
os.path.join(
tmpdir,
f"model_{i}.{xgb.callback.TrainingCheckPoint.default_format}",
)
)
check_point = xgb.callback.TrainingCheckPoint(
directory=tmpdir, interval=1, as_pickle=True, name="model"
@ -478,7 +483,7 @@ class TestCallbacks:
for i in range(1, 10):
assert os.path.exists(os.path.join(tmpdir, "model_" + str(i) + ".pkl"))
def test_callback_list(self):
def test_callback_list(self) -> None:
X, y = tm.data.get_california_housing()
m = xgb.DMatrix(X, y)
callbacks = [xgb.callback.EarlyStopping(rounds=10)]

View File

@ -1590,7 +1590,7 @@ class TestWithDask:
@given(
params=hist_parameter_strategy,
cache_param=hist_cache_strategy,
dataset=tm.make_dataset_strategy()
dataset=tm.make_dataset_strategy(),
)
@settings(
deadline=None, max_examples=10, suppress_health_check=suppress, print_blob=True
@ -2250,16 +2250,27 @@ class TestDaskCallbacks:
],
)
for i in range(1, 10):
assert os.path.exists(os.path.join(tmpdir, "model_" + str(i) + ".json"))
assert os.path.exists(
os.path.join(
tmpdir,
f"model_{i}.{xgb.callback.TrainingCheckPoint.default_format}",
)
)
@gen_cluster(client=True, clean_kwargs={"processes": False, "threads": False}, allow_unclosed=True)
@gen_cluster(
client=True,
clean_kwargs={"processes": False, "threads": False},
allow_unclosed=True,
)
async def test_worker_left(c, s, a, b):
async with Worker(s.address):
dx = da.random.random((1000, 10)).rechunk(chunks=(10, None))
dy = da.random.random((1000,)).rechunk(chunks=(10,))
d_train = await xgb.dask.DaskDMatrix(
c, dx, dy,
c,
dx,
dy,
)
await async_poll_for(lambda: len(s.workers) == 2, timeout=5)
with pytest.raises(RuntimeError, match="Missing"):
@ -2271,12 +2282,19 @@ async def test_worker_left(c, s, a, b):
)
@gen_cluster(client=True, Worker=Nanny, clean_kwargs={"processes": False, "threads": False}, allow_unclosed=True)
@gen_cluster(
client=True,
Worker=Nanny,
clean_kwargs={"processes": False, "threads": False},
allow_unclosed=True,
)
async def test_worker_restarted(c, s, a, b):
dx = da.random.random((1000, 10)).rechunk(chunks=(10, None))
dy = da.random.random((1000,)).rechunk(chunks=(10,))
d_train = await xgb.dask.DaskDMatrix(
c, dx, dy,
c,
dx,
dy,
)
await c.restart_workers([a.worker_address])
with pytest.raises(RuntimeError, match="Missing"):