[breaking] [jvm-packages] Remove rabit check point. (#9599)
- Add `numBoostedRound` to JVM packages. - Remove rabit checkpoint version. - Change the starting version of training continuation in JVM. [breaking] - Redefine the checkpoint version policy in the JVM package. [breaking] - Rename the Python checkpoint callback parameter. [breaking] - Unify the checkpoint policy between Python and JVM.
This commit is contained in:
@@ -540,7 +540,10 @@ class EvaluationMonitor(TrainingCallback):
|
||||
|
||||
|
||||
class TrainingCheckPoint(TrainingCallback):
|
||||
"""Checkpointing operation.
|
||||
"""Checkpointing operation. Users are encouraged to create their own callbacks for
|
||||
checkpoint as XGBoost doesn't handle distributed file systems. When checkpointing on
|
||||
distributed systems, be sure to know the rank of the worker to avoid multiple
|
||||
workers checkpointing to the same place.
|
||||
|
||||
.. versionadded:: 1.3.0
|
||||
|
||||
@@ -553,9 +556,9 @@ class TrainingCheckPoint(TrainingCallback):
|
||||
pattern of output model file. Models will be saved as name_0.json, name_1.json,
|
||||
name_2.json ....
|
||||
as_pickle :
|
||||
When set to True, all training parameters will be saved in pickle format, instead
|
||||
of saving only the model.
|
||||
iterations :
|
||||
When set to True, all training parameters will be saved in pickle format,
|
||||
instead of saving only the model.
|
||||
interval :
|
||||
Interval of checkpointing. Checkpointing is slow so setting a larger number can
|
||||
reduce performance hit.
|
||||
|
||||
@@ -566,15 +569,20 @@ class TrainingCheckPoint(TrainingCallback):
|
||||
directory: Union[str, os.PathLike],
|
||||
name: str = "model",
|
||||
as_pickle: bool = False,
|
||||
iterations: int = 100,
|
||||
interval: int = 100,
|
||||
) -> None:
|
||||
self._path = os.fspath(directory)
|
||||
self._name = name
|
||||
self._as_pickle = as_pickle
|
||||
self._iterations = iterations
|
||||
self._epoch = 0
|
||||
self._iterations = interval
|
||||
self._epoch = 0 # counter for interval
|
||||
self._start = 0 # beginning iteration
|
||||
super().__init__()
|
||||
|
||||
def before_training(self, model: _Model) -> _Model:
|
||||
self._start = model.num_boosted_rounds()
|
||||
return model
|
||||
|
||||
def after_iteration(
|
||||
self, model: _Model, epoch: int, evals_log: TrainingCallback.EvalsLog
|
||||
) -> bool:
|
||||
@@ -583,11 +591,12 @@ class TrainingCheckPoint(TrainingCallback):
|
||||
self._path,
|
||||
self._name
|
||||
+ "_"
|
||||
+ str(epoch)
|
||||
+ (str(epoch + self._start))
|
||||
+ (".pkl" if self._as_pickle else ".json"),
|
||||
)
|
||||
self._epoch = 0
|
||||
self._epoch = 0 # reset counter
|
||||
if collective.get_rank() == 0:
|
||||
# checkpoint using the first worker
|
||||
if self._as_pickle:
|
||||
with open(path, "wb") as fd:
|
||||
pickle.dump(model, fd)
|
||||
|
||||
Reference in New Issue
Block a user