[breaking] [jvm-packages] Remove rabit check point. (#9599)
- Add `numBoostedRound` to jvm packages - Remove rabit checkpoint version. - Change the starting version of training continuation in JVM [breaking]. - Redefine the checkpoint version policy in jvm package. [breaking] - Rename the Python check point callback parameter. [breaking] - Unifies the checkpoint policy between Python and JVM.
This commit is contained in:
parent
7901a299b2
commit
c75a3bc0a9
@ -104,7 +104,7 @@ def check_point_callback():
|
||||
# Use callback class from xgboost.callback
|
||||
# Feel free to subclass/customize it to suit your need.
|
||||
check_point = xgb.callback.TrainingCheckPoint(
|
||||
directory=tmpdir, iterations=rounds, name="model"
|
||||
directory=tmpdir, interval=rounds, name="model"
|
||||
)
|
||||
xgb.train(
|
||||
{"objective": "binary:logistic"},
|
||||
@ -118,7 +118,7 @@ def check_point_callback():
|
||||
# This version of checkpoint saves everything including parameters and
|
||||
# model. See: doc/tutorials/saving_model.rst
|
||||
check_point = xgb.callback.TrainingCheckPoint(
|
||||
directory=tmpdir, iterations=rounds, as_pickle=True, name="model"
|
||||
directory=tmpdir, interval=rounds, as_pickle=True, name="model"
|
||||
)
|
||||
xgb.train(
|
||||
{"objective": "binary:logistic"},
|
||||
|
||||
@ -1308,24 +1308,6 @@ XGB_DLL int XGBoosterSerializeToBuffer(BoosterHandle handle, bst_ulong *out_len,
|
||||
XGB_DLL int XGBoosterUnserializeFromBuffer(BoosterHandle handle,
|
||||
const void *buf, bst_ulong len);
|
||||
|
||||
/*!
|
||||
* \brief Initialize the booster from rabit checkpoint.
|
||||
* This is used in distributed training API.
|
||||
* \param handle handle
|
||||
* \param version The output version of the model.
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
|
||||
int* version);
|
||||
|
||||
/*!
|
||||
* \brief Save the current checkpoint to rabit.
|
||||
* \param handle handle
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle);
|
||||
|
||||
|
||||
/*!
|
||||
* \brief Save XGBoost's internal configuration into a JSON document. Currently the
|
||||
* support is experimental, function signature may change in the future without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -32,57 +32,53 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite
|
||||
}
|
||||
|
||||
private def createNewModels():
|
||||
(String, XGBoostClassificationModel, XGBoostClassificationModel) = {
|
||||
(String, XGBoostClassificationModel, XGBoostClassificationModel) = {
|
||||
val tmpPath = createTmpFolder("test").toAbsolutePath.toString
|
||||
val (model4, model8) = {
|
||||
val (model2, model4) = {
|
||||
val training = buildDataFrame(Classification.train)
|
||||
val paramMap = produceParamMap(tmpPath, 2)
|
||||
(new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
|
||||
new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
|
||||
}
|
||||
(tmpPath, model4, model8)
|
||||
(tmpPath, model2, model4)
|
||||
}
|
||||
|
||||
test("test update/load models") {
|
||||
val (tmpPath, model4, model8) = createNewModels()
|
||||
val (tmpPath, model2, model4) = createNewModels()
|
||||
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
|
||||
|
||||
manager.updateCheckpoint(model4._booster.booster)
|
||||
manager.updateCheckpoint(model2._booster.booster)
|
||||
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||
assert(files.length == 1)
|
||||
assert(files.head.getPath.getName == "4.model")
|
||||
assert(manager.loadCheckpointAsScalaBooster().getVersion == 4)
|
||||
assert(files.head.getPath.getName == "1.model")
|
||||
assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 2)
|
||||
|
||||
manager.updateCheckpoint(model8._booster)
|
||||
manager.updateCheckpoint(model4._booster)
|
||||
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||
assert(files.length == 1)
|
||||
assert(files.head.getPath.getName == "8.model")
|
||||
assert(manager.loadCheckpointAsScalaBooster().getVersion == 8)
|
||||
assert(files.head.getPath.getName == "3.model")
|
||||
assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 4)
|
||||
}
|
||||
|
||||
test("test cleanUpHigherVersions") {
|
||||
val (tmpPath, model4, model8) = createNewModels()
|
||||
val (tmpPath, model2, model4) = createNewModels()
|
||||
|
||||
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
|
||||
manager.updateCheckpoint(model8._booster)
|
||||
manager.cleanUpHigherVersions(8)
|
||||
assert(new File(s"$tmpPath/8.model").exists())
|
||||
manager.updateCheckpoint(model4._booster)
|
||||
manager.cleanUpHigherVersions(3)
|
||||
assert(new File(s"$tmpPath/3.model").exists())
|
||||
|
||||
manager.cleanUpHigherVersions(4)
|
||||
assert(!new File(s"$tmpPath/8.model").exists())
|
||||
manager.cleanUpHigherVersions(2)
|
||||
assert(!new File(s"$tmpPath/3.model").exists())
|
||||
}
|
||||
|
||||
test("test checkpoint rounds") {
|
||||
import scala.collection.JavaConverters._
|
||||
val (tmpPath, model4, model8) = createNewModels()
|
||||
val (tmpPath, model2, model4) = createNewModels()
|
||||
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
|
||||
assertResult(Seq(7))(
|
||||
manager.getCheckpointRounds(0, 7).asScala)
|
||||
assertResult(Seq(2, 4, 6, 7))(
|
||||
manager.getCheckpointRounds(2, 7).asScala)
|
||||
manager.updateCheckpoint(model4._booster)
|
||||
assertResult(Seq(4, 6, 7))(
|
||||
manager.getCheckpointRounds(2, 7).asScala)
|
||||
assertResult(Seq(2))(manager.getCheckpointRounds(0, 0, 3).asScala)
|
||||
assertResult(Seq(0, 2, 4, 6))(manager.getCheckpointRounds(0, 2, 7).asScala)
|
||||
assertResult(Seq(0, 2, 4, 6, 7))(manager.getCheckpointRounds(0, 2, 8).asScala)
|
||||
}
|
||||
|
||||
|
||||
@ -109,8 +105,8 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite
|
||||
// Check only one model is kept after training
|
||||
val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||
assert(files.length == 1)
|
||||
assert(files.head.getPath.getName == "8.model")
|
||||
val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
|
||||
assert(files.head.getPath.getName == "4.model")
|
||||
val tmpModel = SXGBoost.loadModel(s"$tmpPath/4.model")
|
||||
// Train next model based on prev model
|
||||
val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
|
||||
assert(error(tmpModel) >= error(prevModel._booster))
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -787,35 +787,6 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
return importanceMap;
|
||||
}
|
||||
|
||||
/**
|
||||
* Save the model as byte array representation.
|
||||
* Write these bytes to a file will give compatible format with other xgboost bindings.
|
||||
*
|
||||
* If java natively support HDFS file API, use toByteArray and write the ByteArray
|
||||
*
|
||||
* @param withStats Controls whether the split statistics are output.
|
||||
* @return dumped model information
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
private String[] getDumpInfo(boolean withStats) throws XGBoostError {
|
||||
int statsFlag = 0;
|
||||
if (withStats) {
|
||||
statsFlag = 1;
|
||||
}
|
||||
String[][] modelInfos = new String[1][];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterDumpModelEx(handle, "", statsFlag, "text",
|
||||
modelInfos));
|
||||
return modelInfos[0];
|
||||
}
|
||||
|
||||
public int getVersion() {
|
||||
return this.version;
|
||||
}
|
||||
|
||||
public void setVersion(int version) {
|
||||
this.version = version;
|
||||
}
|
||||
|
||||
/**
|
||||
* Save model into raw byte array. Currently it's using the deprecated format as
|
||||
* default, which will be changed into `ubj` in future releases.
|
||||
@ -841,29 +812,6 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
return bytes[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Load the booster model from thread-local rabit checkpoint.
|
||||
* This is only used in distributed training.
|
||||
* @return the stored version number of the checkpoint.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
int loadRabitCheckpoint() throws XGBoostError {
|
||||
int[] out = new int[1];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
|
||||
version = out[0];
|
||||
return version;
|
||||
}
|
||||
|
||||
/**
|
||||
* Save the booster model into thread-local rabit checkpoint and increment the version.
|
||||
* This is only used in distributed training.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
void saveRabitCheckpoint() throws XGBoostError {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
|
||||
version += 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get number of model features.
|
||||
* @return the number of features.
|
||||
@ -874,6 +822,11 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetNumFeature(this.handle, numFeature));
|
||||
return numFeature[0];
|
||||
}
|
||||
public int getNumBoostedRound() throws XGBoostError {
|
||||
int[] numRound = new int[1];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetNumBoostedRound(this.handle, numRound));
|
||||
return numRound[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Internal initialization function.
|
||||
|
||||
@ -1,3 +1,18 @@
|
||||
/*
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.IOException;
|
||||
@ -15,7 +30,7 @@ public class ExternalCheckpointManager {
|
||||
|
||||
private Log logger = LogFactory.getLog("ExternalCheckpointManager");
|
||||
private String modelSuffix = ".model";
|
||||
private Path checkpointPath;
|
||||
private Path checkpointPath; // directory for checkpoints
|
||||
private FileSystem fs;
|
||||
|
||||
public ExternalCheckpointManager(String checkpointPath, FileSystem fs) throws XGBoostError {
|
||||
@ -35,6 +50,7 @@ public class ExternalCheckpointManager {
|
||||
if (!fs.exists(checkpointPath)) {
|
||||
return new ArrayList<>();
|
||||
} else {
|
||||
// Get integer versions from a list of checkpoint files.
|
||||
return Arrays.stream(fs.listStatus(checkpointPath))
|
||||
.map(path -> path.getPath().getName())
|
||||
.filter(fileName -> fileName.endsWith(modelSuffix))
|
||||
@ -44,6 +60,11 @@ public class ExternalCheckpointManager {
|
||||
}
|
||||
}
|
||||
|
||||
private Integer latest(List<Integer> versions) {
|
||||
return versions.stream()
|
||||
.max(Comparator.comparing(Integer::valueOf)).get();
|
||||
}
|
||||
|
||||
public void cleanPath() throws IOException {
|
||||
fs.delete(checkpointPath, true);
|
||||
}
|
||||
@ -51,12 +72,11 @@ public class ExternalCheckpointManager {
|
||||
public Booster loadCheckpointAsBooster() throws IOException, XGBoostError {
|
||||
List<Integer> versions = getExistingVersions();
|
||||
if (versions.size() > 0) {
|
||||
int latestVersion = versions.stream().max(Comparator.comparing(Integer::valueOf)).get();
|
||||
int latestVersion = this.latest(versions);
|
||||
String checkpointPath = getPath(latestVersion);
|
||||
InputStream in = fs.open(new Path(checkpointPath));
|
||||
logger.info("loaded checkpoint from " + checkpointPath);
|
||||
Booster booster = XGBoost.loadModel(in);
|
||||
booster.setVersion(latestVersion);
|
||||
return booster;
|
||||
} else {
|
||||
return null;
|
||||
@ -65,13 +85,16 @@ public class ExternalCheckpointManager {
|
||||
|
||||
public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XGBoostError {
|
||||
List<String> prevModelPaths = getExistingVersions().stream()
|
||||
.map(this::getPath).collect(Collectors.toList());
|
||||
String eventualPath = getPath(boosterToCheckpoint.getVersion());
|
||||
.map(this::getPath).collect(Collectors.toList());
|
||||
// checkpointing is done after update, so n_rounds - 1 is the current iteration
|
||||
// accounting for training continuation.
|
||||
Integer iter = boosterToCheckpoint.getNumBoostedRound() - 1;
|
||||
String eventualPath = getPath(iter);
|
||||
String tempPath = eventualPath + "-" + UUID.randomUUID();
|
||||
try (OutputStream out = fs.create(new Path(tempPath), true)) {
|
||||
boosterToCheckpoint.saveModel(out);
|
||||
fs.rename(new Path(tempPath), new Path(eventualPath));
|
||||
logger.info("saving checkpoint with version " + boosterToCheckpoint.getVersion());
|
||||
logger.info("saving checkpoint with version " + iter);
|
||||
prevModelPaths.stream().forEach(path -> {
|
||||
try {
|
||||
fs.delete(new Path(path), true);
|
||||
@ -83,7 +106,7 @@ public class ExternalCheckpointManager {
|
||||
}
|
||||
|
||||
public void cleanUpHigherVersions(int currentRound) throws IOException {
|
||||
getExistingVersions().stream().filter(v -> v / 2 >= currentRound).forEach(v -> {
|
||||
getExistingVersions().stream().filter(v -> v > currentRound).forEach(v -> {
|
||||
try {
|
||||
fs.delete(new Path(getPath(v)), true);
|
||||
} catch (IOException e) {
|
||||
@ -91,27 +114,26 @@ public class ExternalCheckpointManager {
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public List<Integer> getCheckpointRounds(int checkpointInterval, int numOfRounds)
|
||||
// Get a list of iterations that need checkpointing.
|
||||
public List<Integer> getCheckpointRounds(
|
||||
int firstRound, int checkpointInterval, int numOfRounds)
|
||||
throws IOException {
|
||||
int end = firstRound + numOfRounds; // exclusive
|
||||
int lastRound = end - 1;
|
||||
if (end - 1 < 0) {
|
||||
throw new IllegalArgumentException("Inavlid `numOfRounds`.");
|
||||
}
|
||||
|
||||
List<Integer> arr = new ArrayList<>();
|
||||
if (checkpointInterval > 0) {
|
||||
List<Integer> prevRounds =
|
||||
getExistingVersions().stream().map(v -> v / 2).collect(Collectors.toList());
|
||||
prevRounds.add(0);
|
||||
int firstCheckpointRound = prevRounds.stream()
|
||||
.max(Comparator.comparing(Integer::valueOf)).get() + checkpointInterval;
|
||||
List<Integer> arr = new ArrayList<>();
|
||||
for (int i = firstCheckpointRound; i <= numOfRounds; i += checkpointInterval) {
|
||||
for (int i = firstRound; i < end; i += checkpointInterval) {
|
||||
arr.add(i);
|
||||
}
|
||||
arr.add(numOfRounds);
|
||||
return arr;
|
||||
} else if (checkpointInterval <= 0) {
|
||||
List<Integer> l = new ArrayList<Integer>();
|
||||
l.add(numOfRounds);
|
||||
return l;
|
||||
} else {
|
||||
throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set.");
|
||||
}
|
||||
|
||||
if (!arr.contains(lastRound)) {
|
||||
arr.add(lastRound);
|
||||
}
|
||||
return arr;
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014,2021 by Contributors
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -133,7 +133,7 @@ public class XGBoost {
|
||||
int earlyStoppingRound) throws XGBoostError {
|
||||
return train(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, null);
|
||||
}
|
||||
|
||||
// save checkpoint if iter is in checkpointIterations
|
||||
private static void saveCheckpoint(
|
||||
Booster booster,
|
||||
int iter,
|
||||
@ -169,7 +169,6 @@ public class XGBoost {
|
||||
int bestIteration;
|
||||
List<String> names = new ArrayList<String>();
|
||||
List<DMatrix> mats = new ArrayList<DMatrix>();
|
||||
Set<Integer> checkpointIterations = new HashSet<>();
|
||||
ExternalCheckpointManager ecm = null;
|
||||
if (checkpointPath != null) {
|
||||
ecm = new ExternalCheckpointManager(checkpointPath, fs);
|
||||
@ -203,32 +202,30 @@ public class XGBoost {
|
||||
booster = new Booster(params, allMats);
|
||||
booster.setFeatureNames(dtrain.getFeatureNames());
|
||||
booster.setFeatureTypes(dtrain.getFeatureTypes());
|
||||
booster.loadRabitCheckpoint();
|
||||
} else {
|
||||
// Start training on an existing booster
|
||||
booster.setParams(params);
|
||||
}
|
||||
|
||||
Set<Integer> checkpointIterations = new HashSet<>();
|
||||
if (ecm != null) {
|
||||
checkpointIterations = new HashSet<>(ecm.getCheckpointRounds(checkpointInterval, numRounds));
|
||||
checkpointIterations = new HashSet<>(
|
||||
ecm.getCheckpointRounds(booster.getNumBoostedRound(), checkpointInterval, numRounds));
|
||||
}
|
||||
|
||||
boolean initial_best_score_flag = false;
|
||||
boolean max_direction = false;
|
||||
|
||||
// begin to train
|
||||
for (int iter = booster.getVersion() / 2; iter < numRounds; iter++) {
|
||||
if (booster.getVersion() % 2 == 0) {
|
||||
if (obj != null) {
|
||||
booster.update(dtrain, obj);
|
||||
} else {
|
||||
booster.update(dtrain, iter);
|
||||
}
|
||||
saveCheckpoint(booster, iter, checkpointIterations, ecm);
|
||||
booster.saveRabitCheckpoint();
|
||||
for (int iter = 0; iter < numRounds; iter++) {
|
||||
if (obj != null) {
|
||||
booster.update(dtrain, iter, obj);
|
||||
} else {
|
||||
booster.update(dtrain, iter);
|
||||
}
|
||||
saveCheckpoint(booster, iter, checkpointIterations, ecm);
|
||||
|
||||
//evaluation
|
||||
// evaluation
|
||||
if (evalMats.length > 0) {
|
||||
float[] metricsOut = new float[evalMats.length];
|
||||
String evalInfo;
|
||||
@ -285,7 +282,6 @@ public class XGBoost {
|
||||
Communicator.communicatorPrint(evalInfo + '\n');
|
||||
}
|
||||
}
|
||||
booster.saveRabitCheckpoint();
|
||||
}
|
||||
return booster;
|
||||
}
|
||||
|
||||
@ -140,10 +140,11 @@ class XGBoostJNI {
|
||||
public final static native int XGBoosterGetAttrNames(long handle, String[][] out_strings);
|
||||
public final static native int XGBoosterGetAttr(long handle, String key, String[] out_string);
|
||||
public final static native int XGBoosterSetAttr(long handle, String key, String value);
|
||||
public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version);
|
||||
public final static native int XGBoosterSaveRabitCheckpoint(long handle);
|
||||
|
||||
public final static native int XGBoosterGetNumFeature(long handle, long[] feature);
|
||||
|
||||
public final static native int XGBoosterGetNumBoostedRound(long handle, int[] rounds);
|
||||
|
||||
// communicator functions
|
||||
public final static native int CommunicatorInit(String[] args);
|
||||
public final static native int CommunicatorFinalize();
|
||||
|
||||
@ -326,7 +326,7 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
|
||||
@throws(classOf[XGBoostError])
|
||||
def getNumFeature: Long = booster.getNumFeature
|
||||
|
||||
def getVersion: Int = booster.getVersion
|
||||
def getNumBoostedRound: Long = booster.getNumBoostedRound
|
||||
|
||||
/**
|
||||
* Save model into a raw byte array. Available options are "json", "ubj" and "deprecated".
|
||||
|
||||
@ -984,33 +984,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterLoadRabitCheckpoint
|
||||
* Signature: (J[I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadRabitCheckpoint
|
||||
(JNIEnv *jenv , jclass jcls, jlong jhandle, jintArray jout) {
|
||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||
int version;
|
||||
int ret = XGBoosterLoadRabitCheckpoint(handle, &version);
|
||||
JVM_CHECK_CALL(ret);
|
||||
jint jversion = version;
|
||||
jenv->SetIntArrayRegion(jout, 0, 1, &jversion);
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterSaveRabitCheckpoint
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabitCheckpoint
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle) {
|
||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||
return XGBoosterSaveRabitCheckpoint(handle);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterGetNumFeature
|
||||
@ -1027,6 +1000,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFea
|
||||
return ret;
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoostedRound(
|
||||
JNIEnv *jenv, jclass, jlong jhandle, jintArray jout) {
|
||||
BoosterHandle handle = (BoosterHandle)jhandle;
|
||||
std::int32_t n_rounds{0};
|
||||
auto ret = XGBoosterBoostedRounds(handle, &n_rounds);
|
||||
JVM_CHECK_CALL(ret);
|
||||
jint jn_rounds = n_rounds;
|
||||
jenv->SetIntArrayRegion(jout, 0, 1, &jn_rounds);
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: CommunicatorInit
|
||||
|
||||
@ -287,22 +287,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetAttr
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr
|
||||
(JNIEnv *, jclass, jlong, jstring, jstring);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterLoadRabitCheckpoint
|
||||
* Signature: (J[I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadRabitCheckpoint
|
||||
(JNIEnv *, jclass, jlong, jintArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterSaveRabitCheckpoint
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabitCheckpoint
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterGetNumFeature
|
||||
@ -311,6 +295,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabit
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFeature
|
||||
(JNIEnv *, jclass, jlong, jlongArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterGetNumBoostedRound
|
||||
* Signature: (J[I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoostedRound
|
||||
(JNIEnv *, jclass, jlong, jintArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: CommunicatorInit
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -16,7 +16,6 @@
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import junit.framework.TestCase;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.io.ByteArrayInputStream;
|
||||
@ -31,7 +30,7 @@ import static org.junit.Assert.fail;
|
||||
|
||||
/**
|
||||
* test cases for Booster Inplace Predict
|
||||
*
|
||||
*
|
||||
* @author hzx and Sovrn
|
||||
*/
|
||||
public class BoosterImplTest {
|
||||
@ -845,14 +844,12 @@ public class BoosterImplTest {
|
||||
float tempBoosterError = eval.eval(tempBooster.predict(testMat, true, 0), testMat);
|
||||
|
||||
// Save tempBooster to bytestream and load back
|
||||
int prevVersion = tempBooster.getVersion();
|
||||
ByteArrayInputStream in = new ByteArrayInputStream(tempBooster.toByteArray());
|
||||
tempBooster = XGBoost.loadModel(in);
|
||||
in.close();
|
||||
tempBooster.setVersion(prevVersion);
|
||||
|
||||
// Continue training using tempBooster
|
||||
round = 4;
|
||||
round = 2;
|
||||
Booster booster2 = XGBoost.train(trainMat, paramMap, round, watches, null, null, null, 0, tempBooster);
|
||||
float booster2error = eval.eval(booster2.predict(testMat, true, 0), testMat);
|
||||
TestCase.assertTrue(booster1error == booster2error);
|
||||
|
||||
@ -540,7 +540,10 @@ class EvaluationMonitor(TrainingCallback):
|
||||
|
||||
|
||||
class TrainingCheckPoint(TrainingCallback):
|
||||
"""Checkpointing operation.
|
||||
"""Checkpointing operation. Users are encouraged to create their own callbacks for
|
||||
checkpoint as XGBoost doesn't handle distributed file systems. When checkpointing on
|
||||
distributed systems, be sure to know the rank of the worker to avoid multiple
|
||||
workers checkpointing to the same place.
|
||||
|
||||
.. versionadded:: 1.3.0
|
||||
|
||||
@ -553,9 +556,9 @@ class TrainingCheckPoint(TrainingCallback):
|
||||
pattern of output model file. Models will be saved as name_0.json, name_1.json,
|
||||
name_2.json ....
|
||||
as_pickle :
|
||||
When set to True, all training parameters will be saved in pickle format, instead
|
||||
of saving only the model.
|
||||
iterations :
|
||||
When set to True, all training parameters will be saved in pickle format,
|
||||
instead of saving only the model.
|
||||
interval :
|
||||
Interval of checkpointing. Checkpointing is slow so setting a larger number can
|
||||
reduce performance hit.
|
||||
|
||||
@ -566,15 +569,20 @@ class TrainingCheckPoint(TrainingCallback):
|
||||
directory: Union[str, os.PathLike],
|
||||
name: str = "model",
|
||||
as_pickle: bool = False,
|
||||
iterations: int = 100,
|
||||
interval: int = 100,
|
||||
) -> None:
|
||||
self._path = os.fspath(directory)
|
||||
self._name = name
|
||||
self._as_pickle = as_pickle
|
||||
self._iterations = iterations
|
||||
self._epoch = 0
|
||||
self._iterations = interval
|
||||
self._epoch = 0 # counter for iterval
|
||||
self._start = 0 # beginning iteration
|
||||
super().__init__()
|
||||
|
||||
def before_training(self, model: _Model) -> _Model:
|
||||
self._start = model.num_boosted_rounds()
|
||||
return model
|
||||
|
||||
def after_iteration(
|
||||
self, model: _Model, epoch: int, evals_log: TrainingCallback.EvalsLog
|
||||
) -> bool:
|
||||
@ -583,11 +591,12 @@ class TrainingCheckPoint(TrainingCallback):
|
||||
self._path,
|
||||
self._name
|
||||
+ "_"
|
||||
+ str(epoch)
|
||||
+ (str(epoch + self._start))
|
||||
+ (".pkl" if self._as_pickle else ".json"),
|
||||
)
|
||||
self._epoch = 0
|
||||
self._epoch = 0 # reset counter
|
||||
if collective.get_rank() == 0:
|
||||
# checkpoint using the first worker
|
||||
if self._as_pickle:
|
||||
with open(path, "wb") as fd:
|
||||
pickle.dump(model, fd)
|
||||
|
||||
@ -1430,36 +1430,13 @@ XGB_DLL int XGBoosterUnserializeFromBuffer(BoosterHandle handle,
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
|
||||
int* version) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto* bst = static_cast<Learner*>(handle);
|
||||
xgboost_CHECK_C_ARG_PTR(version);
|
||||
*version = rabit::LoadCheckPoint();
|
||||
if (*version != 0) {
|
||||
bst->Configure();
|
||||
}
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto *learner = static_cast<Learner *>(handle);
|
||||
learner->Configure();
|
||||
rabit::CheckPoint();
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer,
|
||||
int end_layer, int step,
|
||||
XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer, int end_layer, int step,
|
||||
BoosterHandle *out) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
xgboost_CHECK_C_ARG_PTR(out);
|
||||
|
||||
auto* learner = static_cast<Learner*>(handle);
|
||||
auto *learner = static_cast<Learner *>(handle);
|
||||
bool out_of_bound = false;
|
||||
auto p_out = learner->Slice(begin_layer, end_layer, step, &out_of_bound);
|
||||
if (out_of_bound) {
|
||||
|
||||
@ -443,7 +443,7 @@ class TestCallbacks:
|
||||
m = xgb.DMatrix(X, y)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
check_point = xgb.callback.TrainingCheckPoint(
|
||||
directory=tmpdir, iterations=1, name="model"
|
||||
directory=tmpdir, interval=1, name="model"
|
||||
)
|
||||
xgb.train(
|
||||
{"objective": "binary:logistic"},
|
||||
@ -456,7 +456,7 @@ class TestCallbacks:
|
||||
assert os.path.exists(os.path.join(tmpdir, "model_" + str(i) + ".json"))
|
||||
|
||||
check_point = xgb.callback.TrainingCheckPoint(
|
||||
directory=tmpdir, iterations=1, as_pickle=True, name="model"
|
||||
directory=tmpdir, interval=1, as_pickle=True, name="model"
|
||||
)
|
||||
xgb.train(
|
||||
{"objective": "binary:logistic"},
|
||||
|
||||
@ -2238,7 +2238,7 @@ class TestDaskCallbacks:
|
||||
y,
|
||||
callbacks=[
|
||||
xgb.callback.TrainingCheckPoint(
|
||||
directory=Path(tmpdir), iterations=1, name="model"
|
||||
directory=Path(tmpdir), interval=1, name="model"
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user