[breaking] [jvm-packages] Remove rabit checkpoint. (#9599)

- Add `numBoostedRound` to the JVM packages.
- Remove the rabit checkpoint version.
- Change the starting version of training continuation in the JVM packages. [breaking]
- Redefine the checkpoint version policy in the JVM packages. [breaking]
- Rename the Python checkpoint callback parameter (see the usage sketch below). [breaking]
- Unify the checkpoint policy between Python and JVM.
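
A rough usage sketch of the renamed Python parameter, based on the diff below; the data, directory, and checkpoint interval are illustrative and not part of the commit:

import os

import numpy as np
import xgboost as xgb
from xgboost.callback import TrainingCheckPoint

# Toy data so the sketch is self-contained.
X, y = np.random.rand(256, 8), np.random.rand(256)
dtrain = xgb.DMatrix(X, label=y)

os.makedirs("./checkpoints", exist_ok=True)

# `interval` replaces the old `iterations` argument: a model file is written
# every `interval` boosting rounds into `directory`.
ckpt = TrainingCheckPoint(directory="./checkpoints", name="model", interval=10)

booster = xgb.train(
    {"objective": "reg:squarederror"}, dtrain, num_boost_round=100, callbacks=[ckpt]
)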
Author: Jiaming Yuan
Date: 2023-09-26 18:06:34 +08:00
Committed by: GitHub
Parent: 7901a299b2
Commit: c75a3bc0a9
15 changed files with 138 additions and 229 deletions


@@ -540,7 +540,10 @@ class EvaluationMonitor(TrainingCallback):
 class TrainingCheckPoint(TrainingCallback):
-    """Checkpointing operation.
+    """Checkpointing operation. Users are encouraged to create their own callbacks for
+    checkpoint as XGBoost doesn't handle distributed file systems. When checkpointing on
+    distributed systems, be sure to know the rank of the worker to avoid multiple
+    workers checkpointing to the same place.
 
     .. versionadded:: 1.3.0
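
Since the new docstring asks users to handle distributed file systems themselves, a minimal sketch of a rank-aware custom callback might look like the following; the class name and checkpoint layout are assumptions for illustration, not part of this change:

import os

from xgboost import collective
from xgboost.callback import TrainingCallback


class RankAwareCheckPoint(TrainingCallback):
    """Illustrative only: let rank 0 write the checkpoint so that workers sharing a
    file system do not overwrite each other."""

    def __init__(self, directory: str, interval: int = 10) -> None:
        self._dir = directory
        self._interval = interval
        super().__init__()

    def after_iteration(self, model, epoch, evals_log) -> bool:
        # Only the first worker writes; other ranks skip checkpointing entirely.
        if (epoch + 1) % self._interval == 0 and collective.get_rank() == 0:
            model.save_model(os.path.join(self._dir, f"model_{epoch}.json"))
        return False  # returning True would stop training early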
@@ -553,9 +556,9 @@ class TrainingCheckPoint(TrainingCallback):
         pattern of output model file. Models will be saved as name_0.json, name_1.json,
         name_2.json ....
     as_pickle :
-        When set to True, all training parameters will be saved in pickle format, instead
-        of saving only the model.
-    iterations :
+        When set to True, all training parameters will be saved in pickle format,
+        instead of saving only the model.
+    interval :
         Interval of checkpointing. Checkpointing is slow so setting a larger number can
         reduce performance hit.
@@ -566,15 +569,20 @@ class TrainingCheckPoint(TrainingCallback):
         directory: Union[str, os.PathLike],
         name: str = "model",
         as_pickle: bool = False,
-        iterations: int = 100,
+        interval: int = 100,
     ) -> None:
         self._path = os.fspath(directory)
         self._name = name
         self._as_pickle = as_pickle
-        self._iterations = iterations
-        self._epoch = 0
+        self._iterations = interval
+        self._epoch = 0  # counter for interval
+        self._start = 0  # beginning iteration
         super().__init__()
 
+    def before_training(self, model: _Model) -> _Model:
+        self._start = model.num_boosted_rounds()
+        return model
+
     def after_iteration(
         self, model: _Model, epoch: int, evals_log: TrainingCallback.EvalsLog
     ) -> bool:
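
The effect of the new `before_training` hook is that checkpoint numbering continues from the booster's existing round count instead of restarting at zero. Continuing the sketch above (the 100-round `booster`, `dtrain`, and checkpoint settings are assumed from that example):

# `before_training` records booster.num_boosted_rounds() == 100, so with
# interval=10 the continued run writes model_109.json, model_119.json, ...
# rather than starting over at model_9.json.
booster = xgb.train(
    {"objective": "reg:squarederror"},
    dtrain,
    num_boost_round=50,
    xgb_model=booster,
    callbacks=[TrainingCheckPoint(directory="./checkpoints", name="model", interval=10)],
)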
@@ -583,11 +591,12 @@ class TrainingCheckPoint(TrainingCallback):
                 self._path,
                 self._name
                 + "_"
-                + str(epoch)
+                + (str(epoch + self._start))
                 + (".pkl" if self._as_pickle else ".json"),
             )
-            self._epoch = 0
+            self._epoch = 0  # reset counter
             if collective.get_rank() == 0:
+                # checkpoint using the first worker
                 if self._as_pickle:
                     with open(path, "wb") as fd:
                         pickle.dump(model, fd)
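
For completeness, a hedged sketch of how the two checkpoint formats written above could be loaded back; the file names assume the default pattern from the earlier sketch:

import pickle

import xgboost as xgb

# A checkpoint written with as_pickle=True is a pickled Booster; unpickling restores
# the model together with its training parameters.
with open("./checkpoints/model_9.pkl", "rb") as fd:
    booster = pickle.load(fd)

# A JSON checkpoint stores only the model and is loaded through the normal API.
booster = xgb.Booster()
booster.load_model("./checkpoints/model_9.json")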