Use UBJ in Python checkpoint. (#9958)

This commit is contained in:
Jiaming Yuan
2024-01-09 03:22:15 +08:00
committed by GitHub
parent fa5e2f6c45
commit b3eb5d0945
7 changed files with 104 additions and 46 deletions

View File

@@ -7,6 +7,7 @@ Demo for using and defining callback functions
import argparse
import os
import tempfile
from typing import Dict
import numpy as np
from matplotlib import pyplot as plt
@@ -17,24 +18,26 @@ import xgboost as xgb
class Plotting(xgb.callback.TrainingCallback):
"""Plot evaluation result during training. Only for demonstration purpose as it's quite
slow to draw.
"""Plot evaluation result during training. Only for demonstration purpose as it's
quite slow to draw using matplotlib.
"""
def __init__(self, rounds):
def __init__(self, rounds: int) -> None:
self.fig = plt.figure()
self.ax = self.fig.add_subplot(111)
self.rounds = rounds
self.lines = {}
self.lines: Dict[str, plt.Line2D] = {}
self.fig.show()
self.x = np.linspace(0, self.rounds, self.rounds)
plt.ion()
def _get_key(self, data, metric):
def _get_key(self, data: str, metric: str) -> str:
return f"{data}-{metric}"
def after_iteration(self, model, epoch, evals_log):
def after_iteration(
self, model: xgb.Booster, epoch: int, evals_log: Dict[str, dict]
) -> bool:
"""Update the plot."""
if not self.lines:
for data, metric in evals_log.items():
@@ -55,7 +58,7 @@ class Plotting(xgb.callback.TrainingCallback):
return False
def custom_callback():
def custom_callback() -> None:
"""Demo for defining a custom callback function that plots evaluation result during
training."""
X, y = load_breast_cancer(return_X_y=True)
@@ -82,19 +85,27 @@ def custom_callback():
)
def check_point_callback():
# only for demo, set a larger value (like 100) in practice as checkpointing is quite
def check_point_callback() -> None:
"""Demo for using the checkpoint callback. Custom logic for handling output is
usually required and users are encouraged to define their own callback for
checkpointing operations. The builtin one can be used as a starting point.
"""
# Only for demo, set a larger value (like 100) in practice as checkpointing is quite
# slow.
rounds = 2
def check(as_pickle):
def check(as_pickle: bool) -> None:
for i in range(0, 10, rounds):
if i == 0:
continue
if as_pickle:
path = os.path.join(tmpdir, "model_" + str(i) + ".pkl")
else:
path = os.path.join(tmpdir, "model_" + str(i) + ".json")
path = os.path.join(
tmpdir,
f"model_{i}.{xgb.callback.TrainingCheckPoint.default_format}",
)
assert os.path.exists(path)
X, y = load_breast_cancer(return_X_y=True)