[breaking] [jvm-packages] Remove rabit checkpoint. (#9599)

- Add `numBoostedRound` to jvm packages
- Remove rabit checkpoint version.
- Change the starting version of training continuation in JVM [breaking].
- Redefine the checkpoint version policy in jvm package. [breaking]
- Rename the Python checkpoint callback parameter. [breaking]
- Unifies the checkpoint policy between Python and JVM.
This commit is contained in:
Jiaming Yuan 2023-09-26 18:06:34 +08:00 committed by GitHub
parent 7901a299b2
commit c75a3bc0a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 138 additions and 229 deletions

View File

@ -104,7 +104,7 @@ def check_point_callback():
# Use callback class from xgboost.callback # Use callback class from xgboost.callback
# Feel free to subclass/customize it to suit your need. # Feel free to subclass/customize it to suit your need.
check_point = xgb.callback.TrainingCheckPoint( check_point = xgb.callback.TrainingCheckPoint(
directory=tmpdir, iterations=rounds, name="model" directory=tmpdir, interval=rounds, name="model"
) )
xgb.train( xgb.train(
{"objective": "binary:logistic"}, {"objective": "binary:logistic"},
@ -118,7 +118,7 @@ def check_point_callback():
# This version of checkpoint saves everything including parameters and # This version of checkpoint saves everything including parameters and
# model. See: doc/tutorials/saving_model.rst # model. See: doc/tutorials/saving_model.rst
check_point = xgb.callback.TrainingCheckPoint( check_point = xgb.callback.TrainingCheckPoint(
directory=tmpdir, iterations=rounds, as_pickle=True, name="model" directory=tmpdir, interval=rounds, as_pickle=True, name="model"
) )
xgb.train( xgb.train(
{"objective": "binary:logistic"}, {"objective": "binary:logistic"},

View File

@ -1308,24 +1308,6 @@ XGB_DLL int XGBoosterSerializeToBuffer(BoosterHandle handle, bst_ulong *out_len,
XGB_DLL int XGBoosterUnserializeFromBuffer(BoosterHandle handle, XGB_DLL int XGBoosterUnserializeFromBuffer(BoosterHandle handle,
const void *buf, bst_ulong len); const void *buf, bst_ulong len);
/*!
* \brief Initialize the booster from rabit checkpoint.
* This is used in distributed training API.
* \param handle handle
* \param version The output version of the model.
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
int* version);
/*!
* \brief Save the current checkpoint to rabit.
* \param handle handle
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle);
/*! /*!
* \brief Save XGBoost's internal configuration into a JSON document. Currently the * \brief Save XGBoost's internal configuration into a JSON document. Currently the
* support is experimental, function signature may change in the future without * support is experimental, function signature may change in the future without

View File

@ -1,5 +1,5 @@
/* /*
Copyright (c) 2014-2022 by Contributors Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -34,55 +34,51 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite
private def createNewModels(): private def createNewModels():
(String, XGBoostClassificationModel, XGBoostClassificationModel) = { (String, XGBoostClassificationModel, XGBoostClassificationModel) = {
val tmpPath = createTmpFolder("test").toAbsolutePath.toString val tmpPath = createTmpFolder("test").toAbsolutePath.toString
val (model4, model8) = { val (model2, model4) = {
val training = buildDataFrame(Classification.train) val training = buildDataFrame(Classification.train)
val paramMap = produceParamMap(tmpPath, 2) val paramMap = produceParamMap(tmpPath, 2)
(new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training), (new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training)) new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
} }
(tmpPath, model4, model8) (tmpPath, model2, model4)
} }
test("test update/load models") { test("test update/load models") {
val (tmpPath, model4, model8) = createNewModels() val (tmpPath, model2, model4) = createNewModels()
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration)) val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
manager.updateCheckpoint(model4._booster.booster) manager.updateCheckpoint(model2._booster.booster)
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath)) var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1) assert(files.length == 1)
assert(files.head.getPath.getName == "4.model") assert(files.head.getPath.getName == "1.model")
assert(manager.loadCheckpointAsScalaBooster().getVersion == 4) assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 2)
manager.updateCheckpoint(model8._booster) manager.updateCheckpoint(model4._booster)
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath)) files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1) assert(files.length == 1)
assert(files.head.getPath.getName == "8.model") assert(files.head.getPath.getName == "3.model")
assert(manager.loadCheckpointAsScalaBooster().getVersion == 8) assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 4)
} }
test("test cleanUpHigherVersions") { test("test cleanUpHigherVersions") {
val (tmpPath, model4, model8) = createNewModels() val (tmpPath, model2, model4) = createNewModels()
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration)) val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
manager.updateCheckpoint(model8._booster) manager.updateCheckpoint(model4._booster)
manager.cleanUpHigherVersions(8) manager.cleanUpHigherVersions(3)
assert(new File(s"$tmpPath/8.model").exists()) assert(new File(s"$tmpPath/3.model").exists())
manager.cleanUpHigherVersions(4) manager.cleanUpHigherVersions(2)
assert(!new File(s"$tmpPath/8.model").exists()) assert(!new File(s"$tmpPath/3.model").exists())
} }
test("test checkpoint rounds") { test("test checkpoint rounds") {
import scala.collection.JavaConverters._ import scala.collection.JavaConverters._
val (tmpPath, model4, model8) = createNewModels() val (tmpPath, model2, model4) = createNewModels()
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration)) val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
assertResult(Seq(7))( assertResult(Seq(2))(manager.getCheckpointRounds(0, 0, 3).asScala)
manager.getCheckpointRounds(0, 7).asScala) assertResult(Seq(0, 2, 4, 6))(manager.getCheckpointRounds(0, 2, 7).asScala)
assertResult(Seq(2, 4, 6, 7))( assertResult(Seq(0, 2, 4, 6, 7))(manager.getCheckpointRounds(0, 2, 8).asScala)
manager.getCheckpointRounds(2, 7).asScala)
manager.updateCheckpoint(model4._booster)
assertResult(Seq(4, 6, 7))(
manager.getCheckpointRounds(2, 7).asScala)
} }
@ -109,8 +105,8 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite
// Check only one model is kept after training // Check only one model is kept after training
val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath)) val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1) assert(files.length == 1)
assert(files.head.getPath.getName == "8.model") assert(files.head.getPath.getName == "4.model")
val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model") val tmpModel = SXGBoost.loadModel(s"$tmpPath/4.model")
// Train next model based on prev model // Train next model based on prev model
val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training) val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
assert(error(tmpModel) >= error(prevModel._booster)) assert(error(tmpModel) >= error(prevModel._booster))

View File

@ -1,5 +1,5 @@
/* /*
Copyright (c) 2014-2022 by Contributors Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -787,35 +787,6 @@ public class Booster implements Serializable, KryoSerializable {
return importanceMap; return importanceMap;
} }
/**
* Save the model as byte array representation.
* Write these bytes to a file will give compatible format with other xgboost bindings.
*
* If java natively support HDFS file API, use toByteArray and write the ByteArray
*
* @param withStats Controls whether the split statistics are output.
* @return dumped model information
* @throws XGBoostError native error
*/
private String[] getDumpInfo(boolean withStats) throws XGBoostError {
int statsFlag = 0;
if (withStats) {
statsFlag = 1;
}
String[][] modelInfos = new String[1][];
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterDumpModelEx(handle, "", statsFlag, "text",
modelInfos));
return modelInfos[0];
}
public int getVersion() {
return this.version;
}
public void setVersion(int version) {
this.version = version;
}
/** /**
* Save model into raw byte array. Currently it's using the deprecated format as * Save model into raw byte array. Currently it's using the deprecated format as
* default, which will be changed into `ubj` in future releases. * default, which will be changed into `ubj` in future releases.
@ -841,29 +812,6 @@ public class Booster implements Serializable, KryoSerializable {
return bytes[0]; return bytes[0];
} }
/**
* Load the booster model from thread-local rabit checkpoint.
* This is only used in distributed training.
* @return the stored version number of the checkpoint.
* @throws XGBoostError
*/
int loadRabitCheckpoint() throws XGBoostError {
int[] out = new int[1];
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
version = out[0];
return version;
}
/**
* Save the booster model into thread-local rabit checkpoint and increment the version.
* This is only used in distributed training.
* @throws XGBoostError
*/
void saveRabitCheckpoint() throws XGBoostError {
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
version += 1;
}
/** /**
* Get number of model features. * Get number of model features.
* @return the number of features. * @return the number of features.
@ -874,6 +822,11 @@ public class Booster implements Serializable, KryoSerializable {
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetNumFeature(this.handle, numFeature)); XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetNumFeature(this.handle, numFeature));
return numFeature[0]; return numFeature[0];
} }
public int getNumBoostedRound() throws XGBoostError {
int[] numRound = new int[1];
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetNumBoostedRound(this.handle, numRound));
return numRound[0];
}
/** /**
* Internal initialization function. * Internal initialization function.

View File

@ -1,3 +1,18 @@
/*
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java; package ml.dmlc.xgboost4j.java;
import java.io.IOException; import java.io.IOException;
@ -15,7 +30,7 @@ public class ExternalCheckpointManager {
private Log logger = LogFactory.getLog("ExternalCheckpointManager"); private Log logger = LogFactory.getLog("ExternalCheckpointManager");
private String modelSuffix = ".model"; private String modelSuffix = ".model";
private Path checkpointPath; private Path checkpointPath; // directory for checkpoints
private FileSystem fs; private FileSystem fs;
public ExternalCheckpointManager(String checkpointPath, FileSystem fs) throws XGBoostError { public ExternalCheckpointManager(String checkpointPath, FileSystem fs) throws XGBoostError {
@ -35,6 +50,7 @@ public class ExternalCheckpointManager {
if (!fs.exists(checkpointPath)) { if (!fs.exists(checkpointPath)) {
return new ArrayList<>(); return new ArrayList<>();
} else { } else {
// Get integer versions from a list of checkpoint files.
return Arrays.stream(fs.listStatus(checkpointPath)) return Arrays.stream(fs.listStatus(checkpointPath))
.map(path -> path.getPath().getName()) .map(path -> path.getPath().getName())
.filter(fileName -> fileName.endsWith(modelSuffix)) .filter(fileName -> fileName.endsWith(modelSuffix))
@ -44,6 +60,11 @@ public class ExternalCheckpointManager {
} }
} }
private Integer latest(List<Integer> versions) {
return versions.stream()
.max(Comparator.comparing(Integer::valueOf)).get();
}
public void cleanPath() throws IOException { public void cleanPath() throws IOException {
fs.delete(checkpointPath, true); fs.delete(checkpointPath, true);
} }
@ -51,12 +72,11 @@ public class ExternalCheckpointManager {
public Booster loadCheckpointAsBooster() throws IOException, XGBoostError { public Booster loadCheckpointAsBooster() throws IOException, XGBoostError {
List<Integer> versions = getExistingVersions(); List<Integer> versions = getExistingVersions();
if (versions.size() > 0) { if (versions.size() > 0) {
int latestVersion = versions.stream().max(Comparator.comparing(Integer::valueOf)).get(); int latestVersion = this.latest(versions);
String checkpointPath = getPath(latestVersion); String checkpointPath = getPath(latestVersion);
InputStream in = fs.open(new Path(checkpointPath)); InputStream in = fs.open(new Path(checkpointPath));
logger.info("loaded checkpoint from " + checkpointPath); logger.info("loaded checkpoint from " + checkpointPath);
Booster booster = XGBoost.loadModel(in); Booster booster = XGBoost.loadModel(in);
booster.setVersion(latestVersion);
return booster; return booster;
} else { } else {
return null; return null;
@ -66,12 +86,15 @@ public class ExternalCheckpointManager {
public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XGBoostError { public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XGBoostError {
List<String> prevModelPaths = getExistingVersions().stream() List<String> prevModelPaths = getExistingVersions().stream()
.map(this::getPath).collect(Collectors.toList()); .map(this::getPath).collect(Collectors.toList());
String eventualPath = getPath(boosterToCheckpoint.getVersion()); // checkpointing is done after update, so n_rounds - 1 is the current iteration
// accounting for training continuation.
Integer iter = boosterToCheckpoint.getNumBoostedRound() - 1;
String eventualPath = getPath(iter);
String tempPath = eventualPath + "-" + UUID.randomUUID(); String tempPath = eventualPath + "-" + UUID.randomUUID();
try (OutputStream out = fs.create(new Path(tempPath), true)) { try (OutputStream out = fs.create(new Path(tempPath), true)) {
boosterToCheckpoint.saveModel(out); boosterToCheckpoint.saveModel(out);
fs.rename(new Path(tempPath), new Path(eventualPath)); fs.rename(new Path(tempPath), new Path(eventualPath));
logger.info("saving checkpoint with version " + boosterToCheckpoint.getVersion()); logger.info("saving checkpoint with version " + iter);
prevModelPaths.stream().forEach(path -> { prevModelPaths.stream().forEach(path -> {
try { try {
fs.delete(new Path(path), true); fs.delete(new Path(path), true);
@ -83,7 +106,7 @@ public class ExternalCheckpointManager {
} }
public void cleanUpHigherVersions(int currentRound) throws IOException { public void cleanUpHigherVersions(int currentRound) throws IOException {
getExistingVersions().stream().filter(v -> v / 2 >= currentRound).forEach(v -> { getExistingVersions().stream().filter(v -> v > currentRound).forEach(v -> {
try { try {
fs.delete(new Path(getPath(v)), true); fs.delete(new Path(getPath(v)), true);
} catch (IOException e) { } catch (IOException e) {
@ -91,27 +114,26 @@ public class ExternalCheckpointManager {
} }
}); });
} }
// Get a list of iterations that need checkpointing.
public List<Integer> getCheckpointRounds(int checkpointInterval, int numOfRounds) public List<Integer> getCheckpointRounds(
int firstRound, int checkpointInterval, int numOfRounds)
throws IOException { throws IOException {
if (checkpointInterval > 0) { int end = firstRound + numOfRounds; // exclusive
List<Integer> prevRounds = int lastRound = end - 1;
getExistingVersions().stream().map(v -> v / 2).collect(Collectors.toList()); if (end - 1 < 0) {
prevRounds.add(0); throw new IllegalArgumentException("Inavlid `numOfRounds`.");
int firstCheckpointRound = prevRounds.stream() }
.max(Comparator.comparing(Integer::valueOf)).get() + checkpointInterval;
List<Integer> arr = new ArrayList<>(); List<Integer> arr = new ArrayList<>();
for (int i = firstCheckpointRound; i <= numOfRounds; i += checkpointInterval) { if (checkpointInterval > 0) {
for (int i = firstRound; i < end; i += checkpointInterval) {
arr.add(i); arr.add(i);
} }
arr.add(numOfRounds);
return arr;
} else if (checkpointInterval <= 0) {
List<Integer> l = new ArrayList<Integer>();
l.add(numOfRounds);
return l;
} else {
throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set.");
} }
if (!arr.contains(lastRound)) {
arr.add(lastRound);
}
return arr;
} }
} }

View File

@ -1,5 +1,5 @@
/* /*
Copyright (c) 2014,2021 by Contributors Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -133,7 +133,7 @@ public class XGBoost {
int earlyStoppingRound) throws XGBoostError { int earlyStoppingRound) throws XGBoostError {
return train(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, null); return train(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, null);
} }
// save checkpoint if iter is in checkpointIterations
private static void saveCheckpoint( private static void saveCheckpoint(
Booster booster, Booster booster,
int iter, int iter,
@ -169,7 +169,6 @@ public class XGBoost {
int bestIteration; int bestIteration;
List<String> names = new ArrayList<String>(); List<String> names = new ArrayList<String>();
List<DMatrix> mats = new ArrayList<DMatrix>(); List<DMatrix> mats = new ArrayList<DMatrix>();
Set<Integer> checkpointIterations = new HashSet<>();
ExternalCheckpointManager ecm = null; ExternalCheckpointManager ecm = null;
if (checkpointPath != null) { if (checkpointPath != null) {
ecm = new ExternalCheckpointManager(checkpointPath, fs); ecm = new ExternalCheckpointManager(checkpointPath, fs);
@ -203,32 +202,30 @@ public class XGBoost {
booster = new Booster(params, allMats); booster = new Booster(params, allMats);
booster.setFeatureNames(dtrain.getFeatureNames()); booster.setFeatureNames(dtrain.getFeatureNames());
booster.setFeatureTypes(dtrain.getFeatureTypes()); booster.setFeatureTypes(dtrain.getFeatureTypes());
booster.loadRabitCheckpoint();
} else { } else {
// Start training on an existing booster // Start training on an existing booster
booster.setParams(params); booster.setParams(params);
} }
Set<Integer> checkpointIterations = new HashSet<>();
if (ecm != null) { if (ecm != null) {
checkpointIterations = new HashSet<>(ecm.getCheckpointRounds(checkpointInterval, numRounds)); checkpointIterations = new HashSet<>(
ecm.getCheckpointRounds(booster.getNumBoostedRound(), checkpointInterval, numRounds));
} }
boolean initial_best_score_flag = false; boolean initial_best_score_flag = false;
boolean max_direction = false; boolean max_direction = false;
// begin to train // begin to train
for (int iter = booster.getVersion() / 2; iter < numRounds; iter++) { for (int iter = 0; iter < numRounds; iter++) {
if (booster.getVersion() % 2 == 0) {
if (obj != null) { if (obj != null) {
booster.update(dtrain, obj); booster.update(dtrain, iter, obj);
} else { } else {
booster.update(dtrain, iter); booster.update(dtrain, iter);
} }
saveCheckpoint(booster, iter, checkpointIterations, ecm); saveCheckpoint(booster, iter, checkpointIterations, ecm);
booster.saveRabitCheckpoint();
}
//evaluation // evaluation
if (evalMats.length > 0) { if (evalMats.length > 0) {
float[] metricsOut = new float[evalMats.length]; float[] metricsOut = new float[evalMats.length];
String evalInfo; String evalInfo;
@ -285,7 +282,6 @@ public class XGBoost {
Communicator.communicatorPrint(evalInfo + '\n'); Communicator.communicatorPrint(evalInfo + '\n');
} }
} }
booster.saveRabitCheckpoint();
} }
return booster; return booster;
} }

View File

@ -140,10 +140,11 @@ class XGBoostJNI {
public final static native int XGBoosterGetAttrNames(long handle, String[][] out_strings); public final static native int XGBoosterGetAttrNames(long handle, String[][] out_strings);
public final static native int XGBoosterGetAttr(long handle, String key, String[] out_string); public final static native int XGBoosterGetAttr(long handle, String key, String[] out_string);
public final static native int XGBoosterSetAttr(long handle, String key, String value); public final static native int XGBoosterSetAttr(long handle, String key, String value);
public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version);
public final static native int XGBoosterSaveRabitCheckpoint(long handle);
public final static native int XGBoosterGetNumFeature(long handle, long[] feature); public final static native int XGBoosterGetNumFeature(long handle, long[] feature);
public final static native int XGBoosterGetNumBoostedRound(long handle, int[] rounds);
// communicator functions // communicator functions
public final static native int CommunicatorInit(String[] args); public final static native int CommunicatorInit(String[] args);
public final static native int CommunicatorFinalize(); public final static native int CommunicatorFinalize();

View File

@ -326,7 +326,7 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
@throws(classOf[XGBoostError]) @throws(classOf[XGBoostError])
def getNumFeature: Long = booster.getNumFeature def getNumFeature: Long = booster.getNumFeature
def getVersion: Int = booster.getVersion def getNumBoostedRound: Long = booster.getNumBoostedRound
/** /**
* Save model into a raw byte array. Available options are "json", "ubj" and "deprecated". * Save model into a raw byte array. Available options are "json", "ubj" and "deprecated".

View File

@ -984,33 +984,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr
return ret; return ret;
} }
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterLoadRabitCheckpoint
* Signature: (J[I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadRabitCheckpoint
(JNIEnv *jenv , jclass jcls, jlong jhandle, jintArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle;
int version;
int ret = XGBoosterLoadRabitCheckpoint(handle, &version);
JVM_CHECK_CALL(ret);
jint jversion = version;
jenv->SetIntArrayRegion(jout, 0, 1, &jversion);
return ret;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterSaveRabitCheckpoint
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabitCheckpoint
(JNIEnv *jenv, jclass jcls, jlong jhandle) {
BoosterHandle handle = (BoosterHandle) jhandle;
return XGBoosterSaveRabitCheckpoint(handle);
}
/* /*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterGetNumFeature * Method: XGBoosterGetNumFeature
@ -1027,6 +1000,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFea
return ret; return ret;
} }
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoostedRound(
JNIEnv *jenv, jclass, jlong jhandle, jintArray jout) {
BoosterHandle handle = (BoosterHandle)jhandle;
std::int32_t n_rounds{0};
auto ret = XGBoosterBoostedRounds(handle, &n_rounds);
JVM_CHECK_CALL(ret);
jint jn_rounds = n_rounds;
jenv->SetIntArrayRegion(jout, 0, 1, &jn_rounds);
return ret;
}
/* /*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: CommunicatorInit * Method: CommunicatorInit

View File

@ -287,22 +287,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetAttr
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr
(JNIEnv *, jclass, jlong, jstring, jstring); (JNIEnv *, jclass, jlong, jstring, jstring);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterLoadRabitCheckpoint
* Signature: (J[I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadRabitCheckpoint
(JNIEnv *, jclass, jlong, jintArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterSaveRabitCheckpoint
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabitCheckpoint
(JNIEnv *, jclass, jlong);
/* /*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterGetNumFeature * Method: XGBoosterGetNumFeature
@ -311,6 +295,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabit
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFeature JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFeature
(JNIEnv *, jclass, jlong, jlongArray); (JNIEnv *, jclass, jlong, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterGetNumBoostedRound
* Signature: (J[I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoostedRound
(JNIEnv *, jclass, jlong, jintArray);
/* /*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: CommunicatorInit * Method: CommunicatorInit

View File

@ -1,5 +1,5 @@
/* /*
Copyright (c) 2014-2022 by Contributors Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -16,7 +16,6 @@
package ml.dmlc.xgboost4j.java; package ml.dmlc.xgboost4j.java;
import junit.framework.TestCase; import junit.framework.TestCase;
import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
@ -845,14 +844,12 @@ public class BoosterImplTest {
float tempBoosterError = eval.eval(tempBooster.predict(testMat, true, 0), testMat); float tempBoosterError = eval.eval(tempBooster.predict(testMat, true, 0), testMat);
// Save tempBooster to bytestream and load back // Save tempBooster to bytestream and load back
int prevVersion = tempBooster.getVersion();
ByteArrayInputStream in = new ByteArrayInputStream(tempBooster.toByteArray()); ByteArrayInputStream in = new ByteArrayInputStream(tempBooster.toByteArray());
tempBooster = XGBoost.loadModel(in); tempBooster = XGBoost.loadModel(in);
in.close(); in.close();
tempBooster.setVersion(prevVersion);
// Continue training using tempBooster // Continue training using tempBooster
round = 4; round = 2;
Booster booster2 = XGBoost.train(trainMat, paramMap, round, watches, null, null, null, 0, tempBooster); Booster booster2 = XGBoost.train(trainMat, paramMap, round, watches, null, null, null, 0, tempBooster);
float booster2error = eval.eval(booster2.predict(testMat, true, 0), testMat); float booster2error = eval.eval(booster2.predict(testMat, true, 0), testMat);
TestCase.assertTrue(booster1error == booster2error); TestCase.assertTrue(booster1error == booster2error);

View File

@ -540,7 +540,10 @@ class EvaluationMonitor(TrainingCallback):
class TrainingCheckPoint(TrainingCallback): class TrainingCheckPoint(TrainingCallback):
"""Checkpointing operation. """Checkpointing operation. Users are encouraged to create their own callbacks for
checkpoint as XGBoost doesn't handle distributed file systems. When checkpointing on
distributed systems, be sure to know the rank of the worker to avoid multiple
workers checkpointing to the same place.
.. versionadded:: 1.3.0 .. versionadded:: 1.3.0
@ -553,9 +556,9 @@ class TrainingCheckPoint(TrainingCallback):
pattern of output model file. Models will be saved as name_0.json, name_1.json, pattern of output model file. Models will be saved as name_0.json, name_1.json,
name_2.json .... name_2.json ....
as_pickle : as_pickle :
When set to True, all training parameters will be saved in pickle format, instead When set to True, all training parameters will be saved in pickle format,
of saving only the model. instead of saving only the model.
iterations : interval :
Interval of checkpointing. Checkpointing is slow so setting a larger number can Interval of checkpointing. Checkpointing is slow so setting a larger number can
reduce performance hit. reduce performance hit.
@ -566,15 +569,20 @@ class TrainingCheckPoint(TrainingCallback):
directory: Union[str, os.PathLike], directory: Union[str, os.PathLike],
name: str = "model", name: str = "model",
as_pickle: bool = False, as_pickle: bool = False,
iterations: int = 100, interval: int = 100,
) -> None: ) -> None:
self._path = os.fspath(directory) self._path = os.fspath(directory)
self._name = name self._name = name
self._as_pickle = as_pickle self._as_pickle = as_pickle
self._iterations = iterations self._iterations = interval
self._epoch = 0 self._epoch = 0 # counter for iterval
self._start = 0 # beginning iteration
super().__init__() super().__init__()
def before_training(self, model: _Model) -> _Model:
self._start = model.num_boosted_rounds()
return model
def after_iteration( def after_iteration(
self, model: _Model, epoch: int, evals_log: TrainingCallback.EvalsLog self, model: _Model, epoch: int, evals_log: TrainingCallback.EvalsLog
) -> bool: ) -> bool:
@ -583,11 +591,12 @@ class TrainingCheckPoint(TrainingCallback):
self._path, self._path,
self._name self._name
+ "_" + "_"
+ str(epoch) + (str(epoch + self._start))
+ (".pkl" if self._as_pickle else ".json"), + (".pkl" if self._as_pickle else ".json"),
) )
self._epoch = 0 self._epoch = 0 # reset counter
if collective.get_rank() == 0: if collective.get_rank() == 0:
# checkpoint using the first worker
if self._as_pickle: if self._as_pickle:
with open(path, "wb") as fd: with open(path, "wb") as fd:
pickle.dump(model, fd) pickle.dump(model, fd)

View File

@ -1430,36 +1430,13 @@ XGB_DLL int XGBoosterUnserializeFromBuffer(BoosterHandle handle,
API_END(); API_END();
} }
XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle, XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer, int end_layer, int step,
int* version) {
API_BEGIN();
CHECK_HANDLE();
auto* bst = static_cast<Learner*>(handle);
xgboost_CHECK_C_ARG_PTR(version);
*version = rabit::LoadCheckPoint();
if (*version != 0) {
bst->Configure();
}
API_END();
}
XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) {
API_BEGIN();
CHECK_HANDLE();
auto *learner = static_cast<Learner *>(handle);
learner->Configure();
rabit::CheckPoint();
API_END();
}
XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer,
int end_layer, int step,
BoosterHandle *out) { BoosterHandle *out) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE(); CHECK_HANDLE();
xgboost_CHECK_C_ARG_PTR(out); xgboost_CHECK_C_ARG_PTR(out);
auto* learner = static_cast<Learner*>(handle); auto *learner = static_cast<Learner *>(handle);
bool out_of_bound = false; bool out_of_bound = false;
auto p_out = learner->Slice(begin_layer, end_layer, step, &out_of_bound); auto p_out = learner->Slice(begin_layer, end_layer, step, &out_of_bound);
if (out_of_bound) { if (out_of_bound) {

View File

@ -443,7 +443,7 @@ class TestCallbacks:
m = xgb.DMatrix(X, y) m = xgb.DMatrix(X, y)
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
check_point = xgb.callback.TrainingCheckPoint( check_point = xgb.callback.TrainingCheckPoint(
directory=tmpdir, iterations=1, name="model" directory=tmpdir, interval=1, name="model"
) )
xgb.train( xgb.train(
{"objective": "binary:logistic"}, {"objective": "binary:logistic"},
@ -456,7 +456,7 @@ class TestCallbacks:
assert os.path.exists(os.path.join(tmpdir, "model_" + str(i) + ".json")) assert os.path.exists(os.path.join(tmpdir, "model_" + str(i) + ".json"))
check_point = xgb.callback.TrainingCheckPoint( check_point = xgb.callback.TrainingCheckPoint(
directory=tmpdir, iterations=1, as_pickle=True, name="model" directory=tmpdir, interval=1, as_pickle=True, name="model"
) )
xgb.train( xgb.train(
{"objective": "binary:logistic"}, {"objective": "binary:logistic"},

View File

@ -2238,7 +2238,7 @@ class TestDaskCallbacks:
y, y,
callbacks=[ callbacks=[
xgb.callback.TrainingCheckPoint( xgb.callback.TrainingCheckPoint(
directory=Path(tmpdir), iterations=1, name="model" directory=Path(tmpdir), interval=1, name="model"
) )
], ],
) )