[breaking] [jvm-packages] Remove rabit check point. (#9599)
- Add `numBoostedRound` to jvm packages - Remove rabit checkpoint version. - Change the starting version of training continuation in JVM [breaking]. - Redefine the checkpoint version policy in jvm package. [breaking] - Rename the Python check point callback parameter. [breaking] - Unifies the checkpoint policy between Python and JVM.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -32,57 +32,53 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite
|
||||
}
|
||||
|
||||
private def createNewModels():
|
||||
(String, XGBoostClassificationModel, XGBoostClassificationModel) = {
|
||||
(String, XGBoostClassificationModel, XGBoostClassificationModel) = {
|
||||
val tmpPath = createTmpFolder("test").toAbsolutePath.toString
|
||||
val (model4, model8) = {
|
||||
val (model2, model4) = {
|
||||
val training = buildDataFrame(Classification.train)
|
||||
val paramMap = produceParamMap(tmpPath, 2)
|
||||
(new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
|
||||
new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
|
||||
}
|
||||
(tmpPath, model4, model8)
|
||||
(tmpPath, model2, model4)
|
||||
}
|
||||
|
||||
test("test update/load models") {
|
||||
val (tmpPath, model4, model8) = createNewModels()
|
||||
val (tmpPath, model2, model4) = createNewModels()
|
||||
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
|
||||
|
||||
manager.updateCheckpoint(model4._booster.booster)
|
||||
manager.updateCheckpoint(model2._booster.booster)
|
||||
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||
assert(files.length == 1)
|
||||
assert(files.head.getPath.getName == "4.model")
|
||||
assert(manager.loadCheckpointAsScalaBooster().getVersion == 4)
|
||||
assert(files.head.getPath.getName == "1.model")
|
||||
assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 2)
|
||||
|
||||
manager.updateCheckpoint(model8._booster)
|
||||
manager.updateCheckpoint(model4._booster)
|
||||
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||
assert(files.length == 1)
|
||||
assert(files.head.getPath.getName == "8.model")
|
||||
assert(manager.loadCheckpointAsScalaBooster().getVersion == 8)
|
||||
assert(files.head.getPath.getName == "3.model")
|
||||
assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 4)
|
||||
}
|
||||
|
||||
test("test cleanUpHigherVersions") {
|
||||
val (tmpPath, model4, model8) = createNewModels()
|
||||
val (tmpPath, model2, model4) = createNewModels()
|
||||
|
||||
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
|
||||
manager.updateCheckpoint(model8._booster)
|
||||
manager.cleanUpHigherVersions(8)
|
||||
assert(new File(s"$tmpPath/8.model").exists())
|
||||
manager.updateCheckpoint(model4._booster)
|
||||
manager.cleanUpHigherVersions(3)
|
||||
assert(new File(s"$tmpPath/3.model").exists())
|
||||
|
||||
manager.cleanUpHigherVersions(4)
|
||||
assert(!new File(s"$tmpPath/8.model").exists())
|
||||
manager.cleanUpHigherVersions(2)
|
||||
assert(!new File(s"$tmpPath/3.model").exists())
|
||||
}
|
||||
|
||||
test("test checkpoint rounds") {
|
||||
import scala.collection.JavaConverters._
|
||||
val (tmpPath, model4, model8) = createNewModels()
|
||||
val (tmpPath, model2, model4) = createNewModels()
|
||||
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
|
||||
assertResult(Seq(7))(
|
||||
manager.getCheckpointRounds(0, 7).asScala)
|
||||
assertResult(Seq(2, 4, 6, 7))(
|
||||
manager.getCheckpointRounds(2, 7).asScala)
|
||||
manager.updateCheckpoint(model4._booster)
|
||||
assertResult(Seq(4, 6, 7))(
|
||||
manager.getCheckpointRounds(2, 7).asScala)
|
||||
assertResult(Seq(2))(manager.getCheckpointRounds(0, 0, 3).asScala)
|
||||
assertResult(Seq(0, 2, 4, 6))(manager.getCheckpointRounds(0, 2, 7).asScala)
|
||||
assertResult(Seq(0, 2, 4, 6, 7))(manager.getCheckpointRounds(0, 2, 8).asScala)
|
||||
}
|
||||
|
||||
|
||||
@@ -109,8 +105,8 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite
|
||||
// Check only one model is kept after training
|
||||
val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||
assert(files.length == 1)
|
||||
assert(files.head.getPath.getName == "8.model")
|
||||
val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
|
||||
assert(files.head.getPath.getName == "4.model")
|
||||
val tmpModel = SXGBoost.loadModel(s"$tmpPath/4.model")
|
||||
// Train next model based on prev model
|
||||
val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
|
||||
assert(error(tmpModel) >= error(prevModel._booster))
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -787,35 +787,6 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
return importanceMap;
|
||||
}
|
||||
|
||||
/**
|
||||
* Save the model as byte array representation.
|
||||
* Write these bytes to a file will give compatible format with other xgboost bindings.
|
||||
*
|
||||
* If java natively support HDFS file API, use toByteArray and write the ByteArray
|
||||
*
|
||||
* @param withStats Controls whether the split statistics are output.
|
||||
* @return dumped model information
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
private String[] getDumpInfo(boolean withStats) throws XGBoostError {
|
||||
int statsFlag = 0;
|
||||
if (withStats) {
|
||||
statsFlag = 1;
|
||||
}
|
||||
String[][] modelInfos = new String[1][];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterDumpModelEx(handle, "", statsFlag, "text",
|
||||
modelInfos));
|
||||
return modelInfos[0];
|
||||
}
|
||||
|
||||
public int getVersion() {
|
||||
return this.version;
|
||||
}
|
||||
|
||||
public void setVersion(int version) {
|
||||
this.version = version;
|
||||
}
|
||||
|
||||
/**
|
||||
* Save model into raw byte array. Currently it's using the deprecated format as
|
||||
* default, which will be changed into `ubj` in future releases.
|
||||
@@ -841,29 +812,6 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
return bytes[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Load the booster model from thread-local rabit checkpoint.
|
||||
* This is only used in distributed training.
|
||||
* @return the stored version number of the checkpoint.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
int loadRabitCheckpoint() throws XGBoostError {
|
||||
int[] out = new int[1];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
|
||||
version = out[0];
|
||||
return version;
|
||||
}
|
||||
|
||||
/**
|
||||
* Save the booster model into thread-local rabit checkpoint and increment the version.
|
||||
* This is only used in distributed training.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
void saveRabitCheckpoint() throws XGBoostError {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
|
||||
version += 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get number of model features.
|
||||
* @return the number of features.
|
||||
@@ -874,6 +822,11 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetNumFeature(this.handle, numFeature));
|
||||
return numFeature[0];
|
||||
}
|
||||
public int getNumBoostedRound() throws XGBoostError {
|
||||
int[] numRound = new int[1];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetNumBoostedRound(this.handle, numRound));
|
||||
return numRound[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Internal initialization function.
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
/*
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.IOException;
|
||||
@@ -15,7 +30,7 @@ public class ExternalCheckpointManager {
|
||||
|
||||
private Log logger = LogFactory.getLog("ExternalCheckpointManager");
|
||||
private String modelSuffix = ".model";
|
||||
private Path checkpointPath;
|
||||
private Path checkpointPath; // directory for checkpoints
|
||||
private FileSystem fs;
|
||||
|
||||
public ExternalCheckpointManager(String checkpointPath, FileSystem fs) throws XGBoostError {
|
||||
@@ -35,6 +50,7 @@ public class ExternalCheckpointManager {
|
||||
if (!fs.exists(checkpointPath)) {
|
||||
return new ArrayList<>();
|
||||
} else {
|
||||
// Get integer versions from a list of checkpoint files.
|
||||
return Arrays.stream(fs.listStatus(checkpointPath))
|
||||
.map(path -> path.getPath().getName())
|
||||
.filter(fileName -> fileName.endsWith(modelSuffix))
|
||||
@@ -44,6 +60,11 @@ public class ExternalCheckpointManager {
|
||||
}
|
||||
}
|
||||
|
||||
private Integer latest(List<Integer> versions) {
|
||||
return versions.stream()
|
||||
.max(Comparator.comparing(Integer::valueOf)).get();
|
||||
}
|
||||
|
||||
public void cleanPath() throws IOException {
|
||||
fs.delete(checkpointPath, true);
|
||||
}
|
||||
@@ -51,12 +72,11 @@ public class ExternalCheckpointManager {
|
||||
public Booster loadCheckpointAsBooster() throws IOException, XGBoostError {
|
||||
List<Integer> versions = getExistingVersions();
|
||||
if (versions.size() > 0) {
|
||||
int latestVersion = versions.stream().max(Comparator.comparing(Integer::valueOf)).get();
|
||||
int latestVersion = this.latest(versions);
|
||||
String checkpointPath = getPath(latestVersion);
|
||||
InputStream in = fs.open(new Path(checkpointPath));
|
||||
logger.info("loaded checkpoint from " + checkpointPath);
|
||||
Booster booster = XGBoost.loadModel(in);
|
||||
booster.setVersion(latestVersion);
|
||||
return booster;
|
||||
} else {
|
||||
return null;
|
||||
@@ -65,13 +85,16 @@ public class ExternalCheckpointManager {
|
||||
|
||||
public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XGBoostError {
|
||||
List<String> prevModelPaths = getExistingVersions().stream()
|
||||
.map(this::getPath).collect(Collectors.toList());
|
||||
String eventualPath = getPath(boosterToCheckpoint.getVersion());
|
||||
.map(this::getPath).collect(Collectors.toList());
|
||||
// checkpointing is done after update, so n_rounds - 1 is the current iteration
|
||||
// accounting for training continuation.
|
||||
Integer iter = boosterToCheckpoint.getNumBoostedRound() - 1;
|
||||
String eventualPath = getPath(iter);
|
||||
String tempPath = eventualPath + "-" + UUID.randomUUID();
|
||||
try (OutputStream out = fs.create(new Path(tempPath), true)) {
|
||||
boosterToCheckpoint.saveModel(out);
|
||||
fs.rename(new Path(tempPath), new Path(eventualPath));
|
||||
logger.info("saving checkpoint with version " + boosterToCheckpoint.getVersion());
|
||||
logger.info("saving checkpoint with version " + iter);
|
||||
prevModelPaths.stream().forEach(path -> {
|
||||
try {
|
||||
fs.delete(new Path(path), true);
|
||||
@@ -83,7 +106,7 @@ public class ExternalCheckpointManager {
|
||||
}
|
||||
|
||||
public void cleanUpHigherVersions(int currentRound) throws IOException {
|
||||
getExistingVersions().stream().filter(v -> v / 2 >= currentRound).forEach(v -> {
|
||||
getExistingVersions().stream().filter(v -> v > currentRound).forEach(v -> {
|
||||
try {
|
||||
fs.delete(new Path(getPath(v)), true);
|
||||
} catch (IOException e) {
|
||||
@@ -91,27 +114,26 @@ public class ExternalCheckpointManager {
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public List<Integer> getCheckpointRounds(int checkpointInterval, int numOfRounds)
|
||||
// Get a list of iterations that need checkpointing.
|
||||
public List<Integer> getCheckpointRounds(
|
||||
int firstRound, int checkpointInterval, int numOfRounds)
|
||||
throws IOException {
|
||||
int end = firstRound + numOfRounds; // exclusive
|
||||
int lastRound = end - 1;
|
||||
if (end - 1 < 0) {
|
||||
throw new IllegalArgumentException("Inavlid `numOfRounds`.");
|
||||
}
|
||||
|
||||
List<Integer> arr = new ArrayList<>();
|
||||
if (checkpointInterval > 0) {
|
||||
List<Integer> prevRounds =
|
||||
getExistingVersions().stream().map(v -> v / 2).collect(Collectors.toList());
|
||||
prevRounds.add(0);
|
||||
int firstCheckpointRound = prevRounds.stream()
|
||||
.max(Comparator.comparing(Integer::valueOf)).get() + checkpointInterval;
|
||||
List<Integer> arr = new ArrayList<>();
|
||||
for (int i = firstCheckpointRound; i <= numOfRounds; i += checkpointInterval) {
|
||||
for (int i = firstRound; i < end; i += checkpointInterval) {
|
||||
arr.add(i);
|
||||
}
|
||||
arr.add(numOfRounds);
|
||||
return arr;
|
||||
} else if (checkpointInterval <= 0) {
|
||||
List<Integer> l = new ArrayList<Integer>();
|
||||
l.add(numOfRounds);
|
||||
return l;
|
||||
} else {
|
||||
throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set.");
|
||||
}
|
||||
|
||||
if (!arr.contains(lastRound)) {
|
||||
arr.add(lastRound);
|
||||
}
|
||||
return arr;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014,2021 by Contributors
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -133,7 +133,7 @@ public class XGBoost {
|
||||
int earlyStoppingRound) throws XGBoostError {
|
||||
return train(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, null);
|
||||
}
|
||||
|
||||
// save checkpoint if iter is in checkpointIterations
|
||||
private static void saveCheckpoint(
|
||||
Booster booster,
|
||||
int iter,
|
||||
@@ -169,7 +169,6 @@ public class XGBoost {
|
||||
int bestIteration;
|
||||
List<String> names = new ArrayList<String>();
|
||||
List<DMatrix> mats = new ArrayList<DMatrix>();
|
||||
Set<Integer> checkpointIterations = new HashSet<>();
|
||||
ExternalCheckpointManager ecm = null;
|
||||
if (checkpointPath != null) {
|
||||
ecm = new ExternalCheckpointManager(checkpointPath, fs);
|
||||
@@ -203,32 +202,30 @@ public class XGBoost {
|
||||
booster = new Booster(params, allMats);
|
||||
booster.setFeatureNames(dtrain.getFeatureNames());
|
||||
booster.setFeatureTypes(dtrain.getFeatureTypes());
|
||||
booster.loadRabitCheckpoint();
|
||||
} else {
|
||||
// Start training on an existing booster
|
||||
booster.setParams(params);
|
||||
}
|
||||
|
||||
Set<Integer> checkpointIterations = new HashSet<>();
|
||||
if (ecm != null) {
|
||||
checkpointIterations = new HashSet<>(ecm.getCheckpointRounds(checkpointInterval, numRounds));
|
||||
checkpointIterations = new HashSet<>(
|
||||
ecm.getCheckpointRounds(booster.getNumBoostedRound(), checkpointInterval, numRounds));
|
||||
}
|
||||
|
||||
boolean initial_best_score_flag = false;
|
||||
boolean max_direction = false;
|
||||
|
||||
// begin to train
|
||||
for (int iter = booster.getVersion() / 2; iter < numRounds; iter++) {
|
||||
if (booster.getVersion() % 2 == 0) {
|
||||
if (obj != null) {
|
||||
booster.update(dtrain, obj);
|
||||
} else {
|
||||
booster.update(dtrain, iter);
|
||||
}
|
||||
saveCheckpoint(booster, iter, checkpointIterations, ecm);
|
||||
booster.saveRabitCheckpoint();
|
||||
for (int iter = 0; iter < numRounds; iter++) {
|
||||
if (obj != null) {
|
||||
booster.update(dtrain, iter, obj);
|
||||
} else {
|
||||
booster.update(dtrain, iter);
|
||||
}
|
||||
saveCheckpoint(booster, iter, checkpointIterations, ecm);
|
||||
|
||||
//evaluation
|
||||
// evaluation
|
||||
if (evalMats.length > 0) {
|
||||
float[] metricsOut = new float[evalMats.length];
|
||||
String evalInfo;
|
||||
@@ -285,7 +282,6 @@ public class XGBoost {
|
||||
Communicator.communicatorPrint(evalInfo + '\n');
|
||||
}
|
||||
}
|
||||
booster.saveRabitCheckpoint();
|
||||
}
|
||||
return booster;
|
||||
}
|
||||
|
||||
@@ -140,10 +140,11 @@ class XGBoostJNI {
|
||||
public final static native int XGBoosterGetAttrNames(long handle, String[][] out_strings);
|
||||
public final static native int XGBoosterGetAttr(long handle, String key, String[] out_string);
|
||||
public final static native int XGBoosterSetAttr(long handle, String key, String value);
|
||||
public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version);
|
||||
public final static native int XGBoosterSaveRabitCheckpoint(long handle);
|
||||
|
||||
public final static native int XGBoosterGetNumFeature(long handle, long[] feature);
|
||||
|
||||
public final static native int XGBoosterGetNumBoostedRound(long handle, int[] rounds);
|
||||
|
||||
// communicator functions
|
||||
public final static native int CommunicatorInit(String[] args);
|
||||
public final static native int CommunicatorFinalize();
|
||||
|
||||
@@ -326,7 +326,7 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
|
||||
@throws(classOf[XGBoostError])
|
||||
def getNumFeature: Long = booster.getNumFeature
|
||||
|
||||
def getVersion: Int = booster.getVersion
|
||||
def getNumBoostedRound: Long = booster.getNumBoostedRound
|
||||
|
||||
/**
|
||||
* Save model into a raw byte array. Available options are "json", "ubj" and "deprecated".
|
||||
|
||||
@@ -984,33 +984,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterLoadRabitCheckpoint
|
||||
* Signature: (J[I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadRabitCheckpoint
|
||||
(JNIEnv *jenv , jclass jcls, jlong jhandle, jintArray jout) {
|
||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||
int version;
|
||||
int ret = XGBoosterLoadRabitCheckpoint(handle, &version);
|
||||
JVM_CHECK_CALL(ret);
|
||||
jint jversion = version;
|
||||
jenv->SetIntArrayRegion(jout, 0, 1, &jversion);
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterSaveRabitCheckpoint
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabitCheckpoint
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle) {
|
||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||
return XGBoosterSaveRabitCheckpoint(handle);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterGetNumFeature
|
||||
@@ -1027,6 +1000,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFea
|
||||
return ret;
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoostedRound(
|
||||
JNIEnv *jenv, jclass, jlong jhandle, jintArray jout) {
|
||||
BoosterHandle handle = (BoosterHandle)jhandle;
|
||||
std::int32_t n_rounds{0};
|
||||
auto ret = XGBoosterBoostedRounds(handle, &n_rounds);
|
||||
JVM_CHECK_CALL(ret);
|
||||
jint jn_rounds = n_rounds;
|
||||
jenv->SetIntArrayRegion(jout, 0, 1, &jn_rounds);
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: CommunicatorInit
|
||||
|
||||
@@ -287,22 +287,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetAttr
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr
|
||||
(JNIEnv *, jclass, jlong, jstring, jstring);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterLoadRabitCheckpoint
|
||||
* Signature: (J[I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadRabitCheckpoint
|
||||
(JNIEnv *, jclass, jlong, jintArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterSaveRabitCheckpoint
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabitCheckpoint
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterGetNumFeature
|
||||
@@ -311,6 +295,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabit
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFeature
|
||||
(JNIEnv *, jclass, jlong, jlongArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterGetNumBoostedRound
|
||||
* Signature: (J[I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoostedRound
|
||||
(JNIEnv *, jclass, jlong, jintArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: CommunicatorInit
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -16,7 +16,6 @@
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import junit.framework.TestCase;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.io.ByteArrayInputStream;
|
||||
@@ -31,7 +30,7 @@ import static org.junit.Assert.fail;
|
||||
|
||||
/**
|
||||
* test cases for Booster Inplace Predict
|
||||
*
|
||||
*
|
||||
* @author hzx and Sovrn
|
||||
*/
|
||||
public class BoosterImplTest {
|
||||
@@ -845,14 +844,12 @@ public class BoosterImplTest {
|
||||
float tempBoosterError = eval.eval(tempBooster.predict(testMat, true, 0), testMat);
|
||||
|
||||
// Save tempBooster to bytestream and load back
|
||||
int prevVersion = tempBooster.getVersion();
|
||||
ByteArrayInputStream in = new ByteArrayInputStream(tempBooster.toByteArray());
|
||||
tempBooster = XGBoost.loadModel(in);
|
||||
in.close();
|
||||
tempBooster.setVersion(prevVersion);
|
||||
|
||||
// Continue training using tempBooster
|
||||
round = 4;
|
||||
round = 2;
|
||||
Booster booster2 = XGBoost.train(trainMat, paramMap, round, watches, null, null, null, 0, tempBooster);
|
||||
float booster2error = eval.eval(booster2.predict(testMat, true, 0), testMat);
|
||||
TestCase.assertTrue(booster1error == booster2error);
|
||||
|
||||
Reference in New Issue
Block a user