Use UBJ in Python checkpoint. (#9958)
This commit is contained in:
@@ -7,6 +7,7 @@ Demo for using and defining callback functions
|
||||
import argparse
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
from matplotlib import pyplot as plt
|
||||
@@ -17,24 +18,26 @@ import xgboost as xgb
|
||||
|
||||
|
||||
class Plotting(xgb.callback.TrainingCallback):
|
||||
"""Plot evaluation result during training. Only for demonstration purpose as it's quite
|
||||
slow to draw.
|
||||
"""Plot evaluation result during training. Only for demonstration purpose as it's
|
||||
quite slow to draw using matplotlib.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, rounds):
|
||||
def __init__(self, rounds: int) -> None:
|
||||
self.fig = plt.figure()
|
||||
self.ax = self.fig.add_subplot(111)
|
||||
self.rounds = rounds
|
||||
self.lines = {}
|
||||
self.lines: Dict[str, plt.Line2D] = {}
|
||||
self.fig.show()
|
||||
self.x = np.linspace(0, self.rounds, self.rounds)
|
||||
plt.ion()
|
||||
|
||||
def _get_key(self, data, metric):
|
||||
def _get_key(self, data: str, metric: str) -> str:
|
||||
return f"{data}-{metric}"
|
||||
|
||||
def after_iteration(self, model, epoch, evals_log):
|
||||
def after_iteration(
|
||||
self, model: xgb.Booster, epoch: int, evals_log: Dict[str, dict]
|
||||
) -> bool:
|
||||
"""Update the plot."""
|
||||
if not self.lines:
|
||||
for data, metric in evals_log.items():
|
||||
@@ -55,7 +58,7 @@ class Plotting(xgb.callback.TrainingCallback):
|
||||
return False
|
||||
|
||||
|
||||
def custom_callback():
|
||||
def custom_callback() -> None:
|
||||
"""Demo for defining a custom callback function that plots evaluation result during
|
||||
training."""
|
||||
X, y = load_breast_cancer(return_X_y=True)
|
||||
@@ -82,19 +85,27 @@ def custom_callback():
|
||||
)
|
||||
|
||||
|
||||
def check_point_callback():
|
||||
# only for demo, set a larger value (like 100) in practice as checkpointing is quite
|
||||
def check_point_callback() -> None:
|
||||
"""Demo for using the checkpoint callback. Custom logic for handling output is
|
||||
usually required and users are encouraged to define their own callback for
|
||||
checkpointing operations. The builtin one can be used as a starting point.
|
||||
|
||||
"""
|
||||
# Only for demo, set a larger value (like 100) in practice as checkpointing is quite
|
||||
# slow.
|
||||
rounds = 2
|
||||
|
||||
def check(as_pickle):
|
||||
def check(as_pickle: bool) -> None:
|
||||
for i in range(0, 10, rounds):
|
||||
if i == 0:
|
||||
continue
|
||||
if as_pickle:
|
||||
path = os.path.join(tmpdir, "model_" + str(i) + ".pkl")
|
||||
else:
|
||||
path = os.path.join(tmpdir, "model_" + str(i) + ".json")
|
||||
path = os.path.join(
|
||||
tmpdir,
|
||||
f"model_{i}.{xgb.callback.TrainingCheckPoint.default_format}",
|
||||
)
|
||||
assert os.path.exists(path)
|
||||
|
||||
X, y = load_breast_cancer(return_X_y=True)
|
||||
|
||||
Reference in New Issue
Block a user