[breaking] [jvm-packages] Remove rabit check point. (#9599)

- Add `numBoostedRound` to jvm packages
- Remove rabit checkpoint version.
- Change the starting version of training continuation in JVM [breaking].
- Redefine the checkpoint version policy in jvm package. [breaking]
- Rename the Python check point callback parameter. [breaking]
- Unifies the checkpoint policy between Python and JVM.
This commit is contained in:
Jiaming Yuan 2023-09-26 18:06:34 +08:00 committed by GitHub
parent 7901a299b2
commit c75a3bc0a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 138 additions and 229 deletions

View File

@ -104,7 +104,7 @@ def check_point_callback():
# Use callback class from xgboost.callback
# Feel free to subclass/customize it to suit your need.
check_point = xgb.callback.TrainingCheckPoint(
directory=tmpdir, iterations=rounds, name="model"
directory=tmpdir, interval=rounds, name="model"
)
xgb.train(
{"objective": "binary:logistic"},
@ -118,7 +118,7 @@ def check_point_callback():
# This version of checkpoint saves everything including parameters and
# model. See: doc/tutorials/saving_model.rst
check_point = xgb.callback.TrainingCheckPoint(
directory=tmpdir, iterations=rounds, as_pickle=True, name="model"
directory=tmpdir, interval=rounds, as_pickle=True, name="model"
)
xgb.train(
{"objective": "binary:logistic"},

View File

@ -1308,24 +1308,6 @@ XGB_DLL int XGBoosterSerializeToBuffer(BoosterHandle handle, bst_ulong *out_len,
XGB_DLL int XGBoosterUnserializeFromBuffer(BoosterHandle handle,
const void *buf, bst_ulong len);
/*!
* \brief Initialize the booster from rabit checkpoint.
* This is used in distributed training API.
* \param handle handle
* \param version The output version of the model.
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
int* version);
/*!
* \brief Save the current checkpoint to rabit.
* \param handle handle
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle);
/*!
* \brief Save XGBoost's internal configuration into a JSON document. Currently the
* support is experimental, function signature may change in the future without

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -32,57 +32,53 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite
}
private def createNewModels():
(String, XGBoostClassificationModel, XGBoostClassificationModel) = {
(String, XGBoostClassificationModel, XGBoostClassificationModel) = {
val tmpPath = createTmpFolder("test").toAbsolutePath.toString
val (model4, model8) = {
val (model2, model4) = {
val training = buildDataFrame(Classification.train)
val paramMap = produceParamMap(tmpPath, 2)
(new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
}
(tmpPath, model4, model8)
(tmpPath, model2, model4)
}
test("test update/load models") {
val (tmpPath, model4, model8) = createNewModels()
val (tmpPath, model2, model4) = createNewModels()
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
manager.updateCheckpoint(model4._booster.booster)
manager.updateCheckpoint(model2._booster.booster)
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "4.model")
assert(manager.loadCheckpointAsScalaBooster().getVersion == 4)
assert(files.head.getPath.getName == "1.model")
assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 2)
manager.updateCheckpoint(model8._booster)
manager.updateCheckpoint(model4._booster)
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "8.model")
assert(manager.loadCheckpointAsScalaBooster().getVersion == 8)
assert(files.head.getPath.getName == "3.model")
assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 4)
}
test("test cleanUpHigherVersions") {
val (tmpPath, model4, model8) = createNewModels()
val (tmpPath, model2, model4) = createNewModels()
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
manager.updateCheckpoint(model8._booster)
manager.cleanUpHigherVersions(8)
assert(new File(s"$tmpPath/8.model").exists())
manager.updateCheckpoint(model4._booster)
manager.cleanUpHigherVersions(3)
assert(new File(s"$tmpPath/3.model").exists())
manager.cleanUpHigherVersions(4)
assert(!new File(s"$tmpPath/8.model").exists())
manager.cleanUpHigherVersions(2)
assert(!new File(s"$tmpPath/3.model").exists())
}
test("test checkpoint rounds") {
import scala.collection.JavaConverters._
val (tmpPath, model4, model8) = createNewModels()
val (tmpPath, model2, model4) = createNewModels()
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
assertResult(Seq(7))(
manager.getCheckpointRounds(0, 7).asScala)
assertResult(Seq(2, 4, 6, 7))(
manager.getCheckpointRounds(2, 7).asScala)
manager.updateCheckpoint(model4._booster)
assertResult(Seq(4, 6, 7))(
manager.getCheckpointRounds(2, 7).asScala)
assertResult(Seq(2))(manager.getCheckpointRounds(0, 0, 3).asScala)
assertResult(Seq(0, 2, 4, 6))(manager.getCheckpointRounds(0, 2, 7).asScala)
assertResult(Seq(0, 2, 4, 6, 7))(manager.getCheckpointRounds(0, 2, 8).asScala)
}
@ -109,8 +105,8 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite
// Check only one model is kept after training
val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "8.model")
val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
assert(files.head.getPath.getName == "4.model")
val tmpModel = SXGBoost.loadModel(s"$tmpPath/4.model")
// Train next model based on prev model
val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
assert(error(tmpModel) >= error(prevModel._booster))

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -787,35 +787,6 @@ public class Booster implements Serializable, KryoSerializable {
return importanceMap;
}
/**
* Save the model as byte array representation.
* Write these bytes to a file will give compatible format with other xgboost bindings.
*
* If java natively support HDFS file API, use toByteArray and write the ByteArray
*
* @param withStats Controls whether the split statistics are output.
* @return dumped model information
* @throws XGBoostError native error
*/
private String[] getDumpInfo(boolean withStats) throws XGBoostError {
int statsFlag = 0;
if (withStats) {
statsFlag = 1;
}
String[][] modelInfos = new String[1][];
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterDumpModelEx(handle, "", statsFlag, "text",
modelInfos));
return modelInfos[0];
}
public int getVersion() {
return this.version;
}
public void setVersion(int version) {
this.version = version;
}
/**
* Save model into raw byte array. Currently it's using the deprecated format as
* default, which will be changed into `ubj` in future releases.
@ -841,29 +812,6 @@ public class Booster implements Serializable, KryoSerializable {
return bytes[0];
}
/**
* Load the booster model from thread-local rabit checkpoint.
* This is only used in distributed training.
* @return the stored version number of the checkpoint.
* @throws XGBoostError
*/
int loadRabitCheckpoint() throws XGBoostError {
int[] out = new int[1];
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
version = out[0];
return version;
}
/**
* Save the booster model into thread-local rabit checkpoint and increment the version.
* This is only used in distributed training.
* @throws XGBoostError
*/
void saveRabitCheckpoint() throws XGBoostError {
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
version += 1;
}
/**
* Get number of model features.
* @return the number of features.
@ -874,6 +822,11 @@ public class Booster implements Serializable, KryoSerializable {
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetNumFeature(this.handle, numFeature));
return numFeature[0];
}
public int getNumBoostedRound() throws XGBoostError {
int[] numRound = new int[1];
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetNumBoostedRound(this.handle, numRound));
return numRound[0];
}
/**
* Internal initialization function.

View File

@ -1,3 +1,18 @@
/*
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java;
import java.io.IOException;
@ -15,7 +30,7 @@ public class ExternalCheckpointManager {
private Log logger = LogFactory.getLog("ExternalCheckpointManager");
private String modelSuffix = ".model";
private Path checkpointPath;
private Path checkpointPath; // directory for checkpoints
private FileSystem fs;
public ExternalCheckpointManager(String checkpointPath, FileSystem fs) throws XGBoostError {
@ -35,6 +50,7 @@ public class ExternalCheckpointManager {
if (!fs.exists(checkpointPath)) {
return new ArrayList<>();
} else {
// Get integer versions from a list of checkpoint files.
return Arrays.stream(fs.listStatus(checkpointPath))
.map(path -> path.getPath().getName())
.filter(fileName -> fileName.endsWith(modelSuffix))
@ -44,6 +60,11 @@ public class ExternalCheckpointManager {
}
}
private Integer latest(List<Integer> versions) {
return versions.stream()
.max(Comparator.comparing(Integer::valueOf)).get();
}
public void cleanPath() throws IOException {
fs.delete(checkpointPath, true);
}
@ -51,12 +72,11 @@ public class ExternalCheckpointManager {
public Booster loadCheckpointAsBooster() throws IOException, XGBoostError {
List<Integer> versions = getExistingVersions();
if (versions.size() > 0) {
int latestVersion = versions.stream().max(Comparator.comparing(Integer::valueOf)).get();
int latestVersion = this.latest(versions);
String checkpointPath = getPath(latestVersion);
InputStream in = fs.open(new Path(checkpointPath));
logger.info("loaded checkpoint from " + checkpointPath);
Booster booster = XGBoost.loadModel(in);
booster.setVersion(latestVersion);
return booster;
} else {
return null;
@ -65,13 +85,16 @@ public class ExternalCheckpointManager {
public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XGBoostError {
List<String> prevModelPaths = getExistingVersions().stream()
.map(this::getPath).collect(Collectors.toList());
String eventualPath = getPath(boosterToCheckpoint.getVersion());
.map(this::getPath).collect(Collectors.toList());
// checkpointing is done after update, so n_rounds - 1 is the current iteration
// accounting for training continuation.
Integer iter = boosterToCheckpoint.getNumBoostedRound() - 1;
String eventualPath = getPath(iter);
String tempPath = eventualPath + "-" + UUID.randomUUID();
try (OutputStream out = fs.create(new Path(tempPath), true)) {
boosterToCheckpoint.saveModel(out);
fs.rename(new Path(tempPath), new Path(eventualPath));
logger.info("saving checkpoint with version " + boosterToCheckpoint.getVersion());
logger.info("saving checkpoint with version " + iter);
prevModelPaths.stream().forEach(path -> {
try {
fs.delete(new Path(path), true);
@ -83,7 +106,7 @@ public class ExternalCheckpointManager {
}
public void cleanUpHigherVersions(int currentRound) throws IOException {
getExistingVersions().stream().filter(v -> v / 2 >= currentRound).forEach(v -> {
getExistingVersions().stream().filter(v -> v > currentRound).forEach(v -> {
try {
fs.delete(new Path(getPath(v)), true);
} catch (IOException e) {
@ -91,27 +114,26 @@ public class ExternalCheckpointManager {
}
});
}
public List<Integer> getCheckpointRounds(int checkpointInterval, int numOfRounds)
// Get a list of iterations that need checkpointing.
public List<Integer> getCheckpointRounds(
int firstRound, int checkpointInterval, int numOfRounds)
throws IOException {
int end = firstRound + numOfRounds; // exclusive
int lastRound = end - 1;
if (end - 1 < 0) {
      throw new IllegalArgumentException("Invalid `numOfRounds`.");
}
List<Integer> arr = new ArrayList<>();
if (checkpointInterval > 0) {
List<Integer> prevRounds =
getExistingVersions().stream().map(v -> v / 2).collect(Collectors.toList());
prevRounds.add(0);
int firstCheckpointRound = prevRounds.stream()
.max(Comparator.comparing(Integer::valueOf)).get() + checkpointInterval;
List<Integer> arr = new ArrayList<>();
for (int i = firstCheckpointRound; i <= numOfRounds; i += checkpointInterval) {
for (int i = firstRound; i < end; i += checkpointInterval) {
arr.add(i);
}
arr.add(numOfRounds);
return arr;
} else if (checkpointInterval <= 0) {
List<Integer> l = new ArrayList<Integer>();
l.add(numOfRounds);
return l;
} else {
throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set.");
}
if (!arr.contains(lastRound)) {
arr.add(lastRound);
}
return arr;
}
}

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014,2021 by Contributors
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -133,7 +133,7 @@ public class XGBoost {
int earlyStoppingRound) throws XGBoostError {
return train(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, null);
}
// save checkpoint if iter is in checkpointIterations
private static void saveCheckpoint(
Booster booster,
int iter,
@ -169,7 +169,6 @@ public class XGBoost {
int bestIteration;
List<String> names = new ArrayList<String>();
List<DMatrix> mats = new ArrayList<DMatrix>();
Set<Integer> checkpointIterations = new HashSet<>();
ExternalCheckpointManager ecm = null;
if (checkpointPath != null) {
ecm = new ExternalCheckpointManager(checkpointPath, fs);
@ -203,32 +202,30 @@ public class XGBoost {
booster = new Booster(params, allMats);
booster.setFeatureNames(dtrain.getFeatureNames());
booster.setFeatureTypes(dtrain.getFeatureTypes());
booster.loadRabitCheckpoint();
} else {
// Start training on an existing booster
booster.setParams(params);
}
Set<Integer> checkpointIterations = new HashSet<>();
if (ecm != null) {
checkpointIterations = new HashSet<>(ecm.getCheckpointRounds(checkpointInterval, numRounds));
checkpointIterations = new HashSet<>(
ecm.getCheckpointRounds(booster.getNumBoostedRound(), checkpointInterval, numRounds));
}
boolean initial_best_score_flag = false;
boolean max_direction = false;
// begin to train
for (int iter = booster.getVersion() / 2; iter < numRounds; iter++) {
if (booster.getVersion() % 2 == 0) {
if (obj != null) {
booster.update(dtrain, obj);
} else {
booster.update(dtrain, iter);
}
saveCheckpoint(booster, iter, checkpointIterations, ecm);
booster.saveRabitCheckpoint();
for (int iter = 0; iter < numRounds; iter++) {
if (obj != null) {
booster.update(dtrain, iter, obj);
} else {
booster.update(dtrain, iter);
}
saveCheckpoint(booster, iter, checkpointIterations, ecm);
//evaluation
// evaluation
if (evalMats.length > 0) {
float[] metricsOut = new float[evalMats.length];
String evalInfo;
@ -285,7 +282,6 @@ public class XGBoost {
Communicator.communicatorPrint(evalInfo + '\n');
}
}
booster.saveRabitCheckpoint();
}
return booster;
}

View File

@ -140,10 +140,11 @@ class XGBoostJNI {
public final static native int XGBoosterGetAttrNames(long handle, String[][] out_strings);
public final static native int XGBoosterGetAttr(long handle, String key, String[] out_string);
public final static native int XGBoosterSetAttr(long handle, String key, String value);
public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version);
public final static native int XGBoosterSaveRabitCheckpoint(long handle);
public final static native int XGBoosterGetNumFeature(long handle, long[] feature);
public final static native int XGBoosterGetNumBoostedRound(long handle, int[] rounds);
// communicator functions
public final static native int CommunicatorInit(String[] args);
public final static native int CommunicatorFinalize();

View File

@ -326,7 +326,7 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
@throws(classOf[XGBoostError])
def getNumFeature: Long = booster.getNumFeature
def getVersion: Int = booster.getVersion
def getNumBoostedRound: Long = booster.getNumBoostedRound
/**
* Save model into a raw byte array. Available options are "json", "ubj" and "deprecated".

View File

@ -984,33 +984,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr
return ret;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterLoadRabitCheckpoint
* Signature: (J[I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadRabitCheckpoint
(JNIEnv *jenv , jclass jcls, jlong jhandle, jintArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle;
int version;
int ret = XGBoosterLoadRabitCheckpoint(handle, &version);
JVM_CHECK_CALL(ret);
jint jversion = version;
jenv->SetIntArrayRegion(jout, 0, 1, &jversion);
return ret;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterSaveRabitCheckpoint
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabitCheckpoint
(JNIEnv *jenv, jclass jcls, jlong jhandle) {
BoosterHandle handle = (BoosterHandle) jhandle;
return XGBoosterSaveRabitCheckpoint(handle);
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterGetNumFeature
@ -1027,6 +1000,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFea
return ret;
}
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoostedRound(
JNIEnv *jenv, jclass, jlong jhandle, jintArray jout) {
BoosterHandle handle = (BoosterHandle)jhandle;
std::int32_t n_rounds{0};
auto ret = XGBoosterBoostedRounds(handle, &n_rounds);
JVM_CHECK_CALL(ret);
jint jn_rounds = n_rounds;
jenv->SetIntArrayRegion(jout, 0, 1, &jn_rounds);
return ret;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: CommunicatorInit

View File

@ -287,22 +287,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetAttr
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr
(JNIEnv *, jclass, jlong, jstring, jstring);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterLoadRabitCheckpoint
* Signature: (J[I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadRabitCheckpoint
(JNIEnv *, jclass, jlong, jintArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterSaveRabitCheckpoint
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabitCheckpoint
(JNIEnv *, jclass, jlong);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterGetNumFeature
@ -311,6 +295,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabit
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFeature
(JNIEnv *, jclass, jlong, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterGetNumBoostedRound
* Signature: (J[I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoostedRound
(JNIEnv *, jclass, jlong, jintArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: CommunicatorInit

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -16,7 +16,6 @@
package ml.dmlc.xgboost4j.java;
import junit.framework.TestCase;
import org.junit.Assert;
import org.junit.Test;
import java.io.ByteArrayInputStream;
@ -31,7 +30,7 @@ import static org.junit.Assert.fail;
/**
* test cases for Booster Inplace Predict
*
*
* @author hzx and Sovrn
*/
public class BoosterImplTest {
@ -845,14 +844,12 @@ public class BoosterImplTest {
float tempBoosterError = eval.eval(tempBooster.predict(testMat, true, 0), testMat);
// Save tempBooster to bytestream and load back
int prevVersion = tempBooster.getVersion();
ByteArrayInputStream in = new ByteArrayInputStream(tempBooster.toByteArray());
tempBooster = XGBoost.loadModel(in);
in.close();
tempBooster.setVersion(prevVersion);
// Continue training using tempBooster
round = 4;
round = 2;
Booster booster2 = XGBoost.train(trainMat, paramMap, round, watches, null, null, null, 0, tempBooster);
float booster2error = eval.eval(booster2.predict(testMat, true, 0), testMat);
TestCase.assertTrue(booster1error == booster2error);

View File

@ -540,7 +540,10 @@ class EvaluationMonitor(TrainingCallback):
class TrainingCheckPoint(TrainingCallback):
"""Checkpointing operation.
"""Checkpointing operation. Users are encouraged to create their own callbacks for
checkpoint as XGBoost doesn't handle distributed file systems. When checkpointing on
distributed systems, be sure to know the rank of the worker to avoid multiple
workers checkpointing to the same place.
.. versionadded:: 1.3.0
@ -553,9 +556,9 @@ class TrainingCheckPoint(TrainingCallback):
pattern of output model file. Models will be saved as name_0.json, name_1.json,
name_2.json ....
as_pickle :
When set to True, all training parameters will be saved in pickle format, instead
of saving only the model.
iterations :
When set to True, all training parameters will be saved in pickle format,
instead of saving only the model.
interval :
Interval of checkpointing. Checkpointing is slow so setting a larger number can
reduce performance hit.
@ -566,15 +569,20 @@ class TrainingCheckPoint(TrainingCallback):
directory: Union[str, os.PathLike],
name: str = "model",
as_pickle: bool = False,
iterations: int = 100,
interval: int = 100,
) -> None:
self._path = os.fspath(directory)
self._name = name
self._as_pickle = as_pickle
self._iterations = iterations
self._epoch = 0
self._iterations = interval
        self._epoch = 0  # counter for interval
self._start = 0 # beginning iteration
super().__init__()
def before_training(self, model: _Model) -> _Model:
self._start = model.num_boosted_rounds()
return model
def after_iteration(
self, model: _Model, epoch: int, evals_log: TrainingCallback.EvalsLog
) -> bool:
@ -583,11 +591,12 @@ class TrainingCheckPoint(TrainingCallback):
self._path,
self._name
+ "_"
+ str(epoch)
+ (str(epoch + self._start))
+ (".pkl" if self._as_pickle else ".json"),
)
self._epoch = 0
self._epoch = 0 # reset counter
if collective.get_rank() == 0:
# checkpoint using the first worker
if self._as_pickle:
with open(path, "wb") as fd:
pickle.dump(model, fd)

View File

@ -1430,36 +1430,13 @@ XGB_DLL int XGBoosterUnserializeFromBuffer(BoosterHandle handle,
API_END();
}
XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
int* version) {
API_BEGIN();
CHECK_HANDLE();
auto* bst = static_cast<Learner*>(handle);
xgboost_CHECK_C_ARG_PTR(version);
*version = rabit::LoadCheckPoint();
if (*version != 0) {
bst->Configure();
}
API_END();
}
XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) {
API_BEGIN();
CHECK_HANDLE();
auto *learner = static_cast<Learner *>(handle);
learner->Configure();
rabit::CheckPoint();
API_END();
}
XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer,
int end_layer, int step,
XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer, int end_layer, int step,
BoosterHandle *out) {
API_BEGIN();
CHECK_HANDLE();
xgboost_CHECK_C_ARG_PTR(out);
auto* learner = static_cast<Learner*>(handle);
auto *learner = static_cast<Learner *>(handle);
bool out_of_bound = false;
auto p_out = learner->Slice(begin_layer, end_layer, step, &out_of_bound);
if (out_of_bound) {

View File

@ -443,7 +443,7 @@ class TestCallbacks:
m = xgb.DMatrix(X, y)
with tempfile.TemporaryDirectory() as tmpdir:
check_point = xgb.callback.TrainingCheckPoint(
directory=tmpdir, iterations=1, name="model"
directory=tmpdir, interval=1, name="model"
)
xgb.train(
{"objective": "binary:logistic"},
@ -456,7 +456,7 @@ class TestCallbacks:
assert os.path.exists(os.path.join(tmpdir, "model_" + str(i) + ".json"))
check_point = xgb.callback.TrainingCheckPoint(
directory=tmpdir, iterations=1, as_pickle=True, name="model"
directory=tmpdir, interval=1, as_pickle=True, name="model"
)
xgb.train(
{"objective": "binary:logistic"},

View File

@ -2238,7 +2238,7 @@ class TestDaskCallbacks:
y,
callbacks=[
xgb.callback.TrainingCheckPoint(
directory=Path(tmpdir), iterations=1, name="model"
directory=Path(tmpdir), interval=1, name="model"
)
],
)