enable ROCm on latest XGBoost
This commit is contained in:
@@ -4,16 +4,16 @@ list(APPEND JVM_SOURCES
|
||||
${PROJECT_SOURCE_DIR}/jvm-packages/xgboost4j/src/native/xgboost4j.cpp
|
||||
${PROJECT_SOURCE_DIR}/jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cpp)
|
||||
|
||||
if (USE_CUDA)
|
||||
if(USE_CUDA)
|
||||
list(APPEND JVM_SOURCES
|
||||
${PROJECT_SOURCE_DIR}/jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cu)
|
||||
endif (USE_CUDA)
|
||||
endif()
|
||||
|
||||
add_library(xgboost4j SHARED ${JVM_SOURCES} ${XGBOOST_OBJ_SOURCES})
|
||||
|
||||
if (ENABLE_ALL_WARNINGS)
|
||||
if(ENABLE_ALL_WARNINGS)
|
||||
target_compile_options(xgboost4j PUBLIC -Wall -Wextra)
|
||||
endif (ENABLE_ALL_WARNINGS)
|
||||
endif()
|
||||
|
||||
target_link_libraries(xgboost4j PRIVATE objxgboost)
|
||||
target_include_directories(xgboost4j
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -32,57 +32,53 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite
|
||||
}
|
||||
|
||||
private def createNewModels():
|
||||
(String, XGBoostClassificationModel, XGBoostClassificationModel) = {
|
||||
(String, XGBoostClassificationModel, XGBoostClassificationModel) = {
|
||||
val tmpPath = createTmpFolder("test").toAbsolutePath.toString
|
||||
val (model4, model8) = {
|
||||
val (model2, model4) = {
|
||||
val training = buildDataFrame(Classification.train)
|
||||
val paramMap = produceParamMap(tmpPath, 2)
|
||||
(new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
|
||||
new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
|
||||
}
|
||||
(tmpPath, model4, model8)
|
||||
(tmpPath, model2, model4)
|
||||
}
|
||||
|
||||
test("test update/load models") {
|
||||
val (tmpPath, model4, model8) = createNewModels()
|
||||
val (tmpPath, model2, model4) = createNewModels()
|
||||
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
|
||||
|
||||
manager.updateCheckpoint(model4._booster.booster)
|
||||
manager.updateCheckpoint(model2._booster.booster)
|
||||
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||
assert(files.length == 1)
|
||||
assert(files.head.getPath.getName == "4.model")
|
||||
assert(manager.loadCheckpointAsScalaBooster().getVersion == 4)
|
||||
assert(files.head.getPath.getName == "1.model")
|
||||
assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 2)
|
||||
|
||||
manager.updateCheckpoint(model8._booster)
|
||||
manager.updateCheckpoint(model4._booster)
|
||||
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||
assert(files.length == 1)
|
||||
assert(files.head.getPath.getName == "8.model")
|
||||
assert(manager.loadCheckpointAsScalaBooster().getVersion == 8)
|
||||
assert(files.head.getPath.getName == "3.model")
|
||||
assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 4)
|
||||
}
|
||||
|
||||
test("test cleanUpHigherVersions") {
|
||||
val (tmpPath, model4, model8) = createNewModels()
|
||||
val (tmpPath, model2, model4) = createNewModels()
|
||||
|
||||
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
|
||||
manager.updateCheckpoint(model8._booster)
|
||||
manager.cleanUpHigherVersions(8)
|
||||
assert(new File(s"$tmpPath/8.model").exists())
|
||||
manager.updateCheckpoint(model4._booster)
|
||||
manager.cleanUpHigherVersions(3)
|
||||
assert(new File(s"$tmpPath/3.model").exists())
|
||||
|
||||
manager.cleanUpHigherVersions(4)
|
||||
assert(!new File(s"$tmpPath/8.model").exists())
|
||||
manager.cleanUpHigherVersions(2)
|
||||
assert(!new File(s"$tmpPath/3.model").exists())
|
||||
}
|
||||
|
||||
test("test checkpoint rounds") {
|
||||
import scala.collection.JavaConverters._
|
||||
val (tmpPath, model4, model8) = createNewModels()
|
||||
val (tmpPath, model2, model4) = createNewModels()
|
||||
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
|
||||
assertResult(Seq(7))(
|
||||
manager.getCheckpointRounds(0, 7).asScala)
|
||||
assertResult(Seq(2, 4, 6, 7))(
|
||||
manager.getCheckpointRounds(2, 7).asScala)
|
||||
manager.updateCheckpoint(model4._booster)
|
||||
assertResult(Seq(4, 6, 7))(
|
||||
manager.getCheckpointRounds(2, 7).asScala)
|
||||
assertResult(Seq(2))(manager.getCheckpointRounds(0, 0, 3).asScala)
|
||||
assertResult(Seq(0, 2, 4, 6))(manager.getCheckpointRounds(0, 2, 7).asScala)
|
||||
assertResult(Seq(0, 2, 4, 6, 7))(manager.getCheckpointRounds(0, 2, 8).asScala)
|
||||
}
|
||||
|
||||
|
||||
@@ -109,8 +105,8 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite
|
||||
// Check only one model is kept after training
|
||||
val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||
assert(files.length == 1)
|
||||
assert(files.head.getPath.getName == "8.model")
|
||||
val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
|
||||
assert(files.head.getPath.getName == "4.model")
|
||||
val tmpModel = SXGBoost.loadModel(s"$tmpPath/4.model")
|
||||
// Train next model based on prev model
|
||||
val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
|
||||
assert(error(tmpModel) >= error(prevModel._booster))
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -39,6 +39,21 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
// handle to the booster.
|
||||
private long handle = 0;
|
||||
private int version = 0;
|
||||
/**
|
||||
* Type of prediction, used for inplace_predict.
|
||||
*/
|
||||
public enum PredictionType {
|
||||
kValue(0),
|
||||
kMargin(1);
|
||||
|
||||
private Integer ptype;
|
||||
private PredictionType(final Integer ptype) {
|
||||
this.ptype = ptype;
|
||||
}
|
||||
public Integer getPType() {
|
||||
return ptype;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new Booster with empty stage.
|
||||
@@ -375,6 +390,97 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
return predicts;
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform thread-safe prediction.
|
||||
*
|
||||
* @param data Flattened input matrix of features for prediction
|
||||
* @param nrow The number of preditions to make (count of input matrix rows)
|
||||
* @param ncol The number of features in the model (count of input matrix columns)
|
||||
* @param missing Value indicating missing element in the <code>data</code> input matrix
|
||||
*
|
||||
* @return predict Result matrix
|
||||
*/
|
||||
public float[][] inplace_predict(float[] data,
|
||||
int nrow,
|
||||
int ncol,
|
||||
float missing) throws XGBoostError {
|
||||
int[] iteration_range = new int[2];
|
||||
iteration_range[0] = 0;
|
||||
iteration_range[1] = 0;
|
||||
return this.inplace_predict(data, nrow, ncol,
|
||||
missing, iteration_range, PredictionType.kValue, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform thread-safe prediction.
|
||||
*
|
||||
* @param data Flattened input matrix of features for prediction
|
||||
* @param nrow The number of preditions to make (count of input matrix rows)
|
||||
* @param ncol The number of features in the model (count of input matrix columns)
|
||||
* @param missing Value indicating missing element in the <code>data</code> input matrix
|
||||
* @param iteration_range Specifies which layer of trees are used in prediction. For
|
||||
* example, if a random forest is trained with 100 rounds.
|
||||
* Specifying `iteration_range=[10, 20)`, then only the forests
|
||||
* built during [10, 20) (half open set) rounds are used in this
|
||||
* prediction.
|
||||
*
|
||||
* @return predict Result matrix
|
||||
*/
|
||||
public float[][] inplace_predict(float[] data,
|
||||
int nrow,
|
||||
int ncol,
|
||||
float missing, int[] iteration_range) throws XGBoostError {
|
||||
return this.inplace_predict(data, nrow, ncol,
|
||||
missing, iteration_range, PredictionType.kValue, null);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Perform thread-safe prediction.
|
||||
*
|
||||
* @param data Flattened input matrix of features for prediction
|
||||
* @param nrow The number of preditions to make (count of input matrix rows)
|
||||
* @param ncol The number of features in the model (count of input matrix columns)
|
||||
* @param missing Value indicating missing element in the <code>data</code> input matrix
|
||||
* @param iteration_range Specifies which layer of trees are used in prediction. For
|
||||
* example, if a random forest is trained with 100 rounds.
|
||||
* Specifying `iteration_range=[10, 20)`, then only the forests
|
||||
* built during [10, 20) (half open set) rounds are used in this
|
||||
* prediction.
|
||||
* @param predict_type What kind of prediction to run.
|
||||
* @return predict Result matrix
|
||||
*/
|
||||
public float[][] inplace_predict(float[] data,
|
||||
int nrow,
|
||||
int ncol,
|
||||
float missing,
|
||||
int[] iteration_range,
|
||||
PredictionType predict_type,
|
||||
float[] base_margin) throws XGBoostError {
|
||||
if (iteration_range.length != 2) {
|
||||
throw new XGBoostError(new String("Iteration range is expected to be [begin, end)."));
|
||||
}
|
||||
int ptype = predict_type.getPType();
|
||||
|
||||
int begin = iteration_range[0];
|
||||
int end = iteration_range[1];
|
||||
|
||||
float[][] rawPredicts = new float[1][];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterPredictFromDense(handle, data, nrow, ncol,
|
||||
missing,
|
||||
begin, end, ptype, base_margin, rawPredicts));
|
||||
|
||||
int col = rawPredicts[0].length / nrow;
|
||||
float[][] predicts = new float[nrow][col];
|
||||
int r, c;
|
||||
for (int i = 0; i < rawPredicts[0].length; i++) {
|
||||
r = i / col;
|
||||
c = i % col;
|
||||
predicts[r][c] = rawPredicts[0][i];
|
||||
}
|
||||
return predicts;
|
||||
}
|
||||
|
||||
/**
|
||||
* Predict leaf indices given the data
|
||||
*
|
||||
@@ -681,35 +787,6 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
return importanceMap;
|
||||
}
|
||||
|
||||
/**
|
||||
* Save the model as byte array representation.
|
||||
* Write these bytes to a file will give compatible format with other xgboost bindings.
|
||||
*
|
||||
* If java natively support HDFS file API, use toByteArray and write the ByteArray
|
||||
*
|
||||
* @param withStats Controls whether the split statistics are output.
|
||||
* @return dumped model information
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
private String[] getDumpInfo(boolean withStats) throws XGBoostError {
|
||||
int statsFlag = 0;
|
||||
if (withStats) {
|
||||
statsFlag = 1;
|
||||
}
|
||||
String[][] modelInfos = new String[1][];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterDumpModelEx(handle, "", statsFlag, "text",
|
||||
modelInfos));
|
||||
return modelInfos[0];
|
||||
}
|
||||
|
||||
public int getVersion() {
|
||||
return this.version;
|
||||
}
|
||||
|
||||
public void setVersion(int version) {
|
||||
this.version = version;
|
||||
}
|
||||
|
||||
/**
|
||||
* Save model into raw byte array. Currently it's using the deprecated format as
|
||||
* default, which will be changed into `ubj` in future releases.
|
||||
@@ -735,29 +812,6 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
return bytes[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Load the booster model from thread-local rabit checkpoint.
|
||||
* This is only used in distributed training.
|
||||
* @return the stored version number of the checkpoint.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
int loadRabitCheckpoint() throws XGBoostError {
|
||||
int[] out = new int[1];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
|
||||
version = out[0];
|
||||
return version;
|
||||
}
|
||||
|
||||
/**
|
||||
* Save the booster model into thread-local rabit checkpoint and increment the version.
|
||||
* This is only used in distributed training.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
void saveRabitCheckpoint() throws XGBoostError {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
|
||||
version += 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get number of model features.
|
||||
* @return the number of features.
|
||||
@@ -768,6 +822,11 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetNumFeature(this.handle, numFeature));
|
||||
return numFeature[0];
|
||||
}
|
||||
public int getNumBoostedRound() throws XGBoostError {
|
||||
int[] numRound = new int[1];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetNumBoostedRound(this.handle, numRound));
|
||||
return numRound[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Internal initialization function.
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
/*
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.IOException;
|
||||
@@ -15,7 +30,7 @@ public class ExternalCheckpointManager {
|
||||
|
||||
private Log logger = LogFactory.getLog("ExternalCheckpointManager");
|
||||
private String modelSuffix = ".model";
|
||||
private Path checkpointPath;
|
||||
private Path checkpointPath; // directory for checkpoints
|
||||
private FileSystem fs;
|
||||
|
||||
public ExternalCheckpointManager(String checkpointPath, FileSystem fs) throws XGBoostError {
|
||||
@@ -35,6 +50,7 @@ public class ExternalCheckpointManager {
|
||||
if (!fs.exists(checkpointPath)) {
|
||||
return new ArrayList<>();
|
||||
} else {
|
||||
// Get integer versions from a list of checkpoint files.
|
||||
return Arrays.stream(fs.listStatus(checkpointPath))
|
||||
.map(path -> path.getPath().getName())
|
||||
.filter(fileName -> fileName.endsWith(modelSuffix))
|
||||
@@ -44,6 +60,11 @@ public class ExternalCheckpointManager {
|
||||
}
|
||||
}
|
||||
|
||||
private Integer latest(List<Integer> versions) {
|
||||
return versions.stream()
|
||||
.max(Comparator.comparing(Integer::valueOf)).get();
|
||||
}
|
||||
|
||||
public void cleanPath() throws IOException {
|
||||
fs.delete(checkpointPath, true);
|
||||
}
|
||||
@@ -51,12 +72,11 @@ public class ExternalCheckpointManager {
|
||||
public Booster loadCheckpointAsBooster() throws IOException, XGBoostError {
|
||||
List<Integer> versions = getExistingVersions();
|
||||
if (versions.size() > 0) {
|
||||
int latestVersion = versions.stream().max(Comparator.comparing(Integer::valueOf)).get();
|
||||
int latestVersion = this.latest(versions);
|
||||
String checkpointPath = getPath(latestVersion);
|
||||
InputStream in = fs.open(new Path(checkpointPath));
|
||||
logger.info("loaded checkpoint from " + checkpointPath);
|
||||
Booster booster = XGBoost.loadModel(in);
|
||||
booster.setVersion(latestVersion);
|
||||
return booster;
|
||||
} else {
|
||||
return null;
|
||||
@@ -65,13 +85,16 @@ public class ExternalCheckpointManager {
|
||||
|
||||
public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XGBoostError {
|
||||
List<String> prevModelPaths = getExistingVersions().stream()
|
||||
.map(this::getPath).collect(Collectors.toList());
|
||||
String eventualPath = getPath(boosterToCheckpoint.getVersion());
|
||||
.map(this::getPath).collect(Collectors.toList());
|
||||
// checkpointing is done after update, so n_rounds - 1 is the current iteration
|
||||
// accounting for training continuation.
|
||||
Integer iter = boosterToCheckpoint.getNumBoostedRound() - 1;
|
||||
String eventualPath = getPath(iter);
|
||||
String tempPath = eventualPath + "-" + UUID.randomUUID();
|
||||
try (OutputStream out = fs.create(new Path(tempPath), true)) {
|
||||
boosterToCheckpoint.saveModel(out);
|
||||
fs.rename(new Path(tempPath), new Path(eventualPath));
|
||||
logger.info("saving checkpoint with version " + boosterToCheckpoint.getVersion());
|
||||
logger.info("saving checkpoint with version " + iter);
|
||||
prevModelPaths.stream().forEach(path -> {
|
||||
try {
|
||||
fs.delete(new Path(path), true);
|
||||
@@ -83,7 +106,7 @@ public class ExternalCheckpointManager {
|
||||
}
|
||||
|
||||
public void cleanUpHigherVersions(int currentRound) throws IOException {
|
||||
getExistingVersions().stream().filter(v -> v / 2 >= currentRound).forEach(v -> {
|
||||
getExistingVersions().stream().filter(v -> v > currentRound).forEach(v -> {
|
||||
try {
|
||||
fs.delete(new Path(getPath(v)), true);
|
||||
} catch (IOException e) {
|
||||
@@ -91,27 +114,26 @@ public class ExternalCheckpointManager {
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public List<Integer> getCheckpointRounds(int checkpointInterval, int numOfRounds)
|
||||
// Get a list of iterations that need checkpointing.
|
||||
public List<Integer> getCheckpointRounds(
|
||||
int firstRound, int checkpointInterval, int numOfRounds)
|
||||
throws IOException {
|
||||
int end = firstRound + numOfRounds; // exclusive
|
||||
int lastRound = end - 1;
|
||||
if (end - 1 < 0) {
|
||||
throw new IllegalArgumentException("Inavlid `numOfRounds`.");
|
||||
}
|
||||
|
||||
List<Integer> arr = new ArrayList<>();
|
||||
if (checkpointInterval > 0) {
|
||||
List<Integer> prevRounds =
|
||||
getExistingVersions().stream().map(v -> v / 2).collect(Collectors.toList());
|
||||
prevRounds.add(0);
|
||||
int firstCheckpointRound = prevRounds.stream()
|
||||
.max(Comparator.comparing(Integer::valueOf)).get() + checkpointInterval;
|
||||
List<Integer> arr = new ArrayList<>();
|
||||
for (int i = firstCheckpointRound; i <= numOfRounds; i += checkpointInterval) {
|
||||
for (int i = firstRound; i < end; i += checkpointInterval) {
|
||||
arr.add(i);
|
||||
}
|
||||
arr.add(numOfRounds);
|
||||
return arr;
|
||||
} else if (checkpointInterval <= 0) {
|
||||
List<Integer> l = new ArrayList<Integer>();
|
||||
l.add(numOfRounds);
|
||||
return l;
|
||||
} else {
|
||||
throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set.");
|
||||
}
|
||||
|
||||
if (!arr.contains(lastRound)) {
|
||||
arr.add(lastRound);
|
||||
}
|
||||
return arr;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014,2021 by Contributors
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -133,7 +133,7 @@ public class XGBoost {
|
||||
int earlyStoppingRound) throws XGBoostError {
|
||||
return train(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, null);
|
||||
}
|
||||
|
||||
// save checkpoint if iter is in checkpointIterations
|
||||
private static void saveCheckpoint(
|
||||
Booster booster,
|
||||
int iter,
|
||||
@@ -169,7 +169,6 @@ public class XGBoost {
|
||||
int bestIteration;
|
||||
List<String> names = new ArrayList<String>();
|
||||
List<DMatrix> mats = new ArrayList<DMatrix>();
|
||||
Set<Integer> checkpointIterations = new HashSet<>();
|
||||
ExternalCheckpointManager ecm = null;
|
||||
if (checkpointPath != null) {
|
||||
ecm = new ExternalCheckpointManager(checkpointPath, fs);
|
||||
@@ -203,32 +202,30 @@ public class XGBoost {
|
||||
booster = new Booster(params, allMats);
|
||||
booster.setFeatureNames(dtrain.getFeatureNames());
|
||||
booster.setFeatureTypes(dtrain.getFeatureTypes());
|
||||
booster.loadRabitCheckpoint();
|
||||
} else {
|
||||
// Start training on an existing booster
|
||||
booster.setParams(params);
|
||||
}
|
||||
|
||||
Set<Integer> checkpointIterations = new HashSet<>();
|
||||
if (ecm != null) {
|
||||
checkpointIterations = new HashSet<>(ecm.getCheckpointRounds(checkpointInterval, numRounds));
|
||||
checkpointIterations = new HashSet<>(
|
||||
ecm.getCheckpointRounds(booster.getNumBoostedRound(), checkpointInterval, numRounds));
|
||||
}
|
||||
|
||||
boolean initial_best_score_flag = false;
|
||||
boolean max_direction = false;
|
||||
|
||||
// begin to train
|
||||
for (int iter = booster.getVersion() / 2; iter < numRounds; iter++) {
|
||||
if (booster.getVersion() % 2 == 0) {
|
||||
if (obj != null) {
|
||||
booster.update(dtrain, obj);
|
||||
} else {
|
||||
booster.update(dtrain, iter);
|
||||
}
|
||||
saveCheckpoint(booster, iter, checkpointIterations, ecm);
|
||||
booster.saveRabitCheckpoint();
|
||||
for (int iter = 0; iter < numRounds; iter++) {
|
||||
if (obj != null) {
|
||||
booster.update(dtrain, iter, obj);
|
||||
} else {
|
||||
booster.update(dtrain, iter);
|
||||
}
|
||||
saveCheckpoint(booster, iter, checkpointIterations, ecm);
|
||||
|
||||
//evaluation
|
||||
// evaluation
|
||||
if (evalMats.length > 0) {
|
||||
float[] metricsOut = new float[evalMats.length];
|
||||
String evalInfo;
|
||||
@@ -285,7 +282,6 @@ public class XGBoost {
|
||||
Communicator.communicatorPrint(evalInfo + '\n');
|
||||
}
|
||||
}
|
||||
booster.saveRabitCheckpoint();
|
||||
}
|
||||
return booster;
|
||||
}
|
||||
|
||||
@@ -119,6 +119,10 @@ class XGBoostJNI {
|
||||
public final static native int XGBoosterPredict(long handle, long dmat, int option_mask,
|
||||
int ntree_limit, float[][] predicts);
|
||||
|
||||
public final static native int XGBoosterPredictFromDense(long handle, float[] data,
|
||||
long nrow, long ncol, float missing, int iteration_begin, int iteration_end, int predict_type, float[] margin,
|
||||
float[][] predicts);
|
||||
|
||||
public final static native int XGBoosterLoadModel(long handle, String fname);
|
||||
|
||||
public final static native int XGBoosterSaveModel(long handle, String fname);
|
||||
@@ -136,10 +140,11 @@ class XGBoostJNI {
|
||||
public final static native int XGBoosterGetAttrNames(long handle, String[][] out_strings);
|
||||
public final static native int XGBoosterGetAttr(long handle, String key, String[] out_string);
|
||||
public final static native int XGBoosterSetAttr(long handle, String key, String value);
|
||||
public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version);
|
||||
public final static native int XGBoosterSaveRabitCheckpoint(long handle);
|
||||
|
||||
public final static native int XGBoosterGetNumFeature(long handle, long[] feature);
|
||||
|
||||
public final static native int XGBoosterGetNumBoostedRound(long handle, int[] rounds);
|
||||
|
||||
// communicator functions
|
||||
public final static native int CommunicatorInit(String[] args);
|
||||
public final static native int CommunicatorFinalize();
|
||||
@@ -154,10 +159,6 @@ class XGBoostJNI {
|
||||
public final static native int XGDMatrixSetInfoFromInterface(
|
||||
long handle, String field, String json);
|
||||
|
||||
@Deprecated
|
||||
public final static native int XGDeviceQuantileDMatrixCreateFromCallback(
|
||||
java.util.Iterator<ColumnBatch> iter, float missing, int nthread, int maxBin, long[] out);
|
||||
|
||||
public final static native int XGQuantileDMatrixCreateFromCallback(
|
||||
java.util.Iterator<ColumnBatch> iter, java.util.Iterator<ColumnBatch> ref, String config, long[] out);
|
||||
|
||||
|
||||
@@ -326,7 +326,7 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
|
||||
@throws(classOf[XGBoostError])
|
||||
def getNumFeature: Long = booster.getNumFeature
|
||||
|
||||
def getVersion: Int = booster.getVersion
|
||||
def getNumBoostedRound: Long = booster.getNumBoostedRound
|
||||
|
||||
/**
|
||||
* Save model into a raw byte array. Available options are "json", "ubj" and "deprecated".
|
||||
|
||||
@@ -684,6 +684,85 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterPredictFromDense
|
||||
* Signature: (J[FJJFIII[F[[F)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredictFromDense(
|
||||
JNIEnv *jenv, jclass jcls, jlong jhandle, jfloatArray jdata, jlong num_rows, jlong num_features,
|
||||
jfloat missing, jint iteration_begin, jint iteration_end, jint predict_type,
|
||||
jfloatArray jmargin, jobjectArray jout) {
|
||||
API_BEGIN();
|
||||
BoosterHandle handle = reinterpret_cast<BoosterHandle>(jhandle);
|
||||
|
||||
/**
|
||||
* Create array interface.
|
||||
*/
|
||||
namespace linalg = xgboost::linalg;
|
||||
jfloat *data = jenv->GetFloatArrayElements(jdata, nullptr);
|
||||
xgboost::Context ctx;
|
||||
auto t_data = linalg::MakeTensorView(
|
||||
ctx.Device(),
|
||||
xgboost::common::Span{data, static_cast<std::size_t>(num_rows * num_features)}, num_rows,
|
||||
num_features);
|
||||
auto s_array = linalg::ArrayInterfaceStr(t_data);
|
||||
|
||||
/**
|
||||
* Create configuration object.
|
||||
*/
|
||||
xgboost::Json config{xgboost::Object{}};
|
||||
config["cache_id"] = xgboost::Integer{};
|
||||
config["type"] = xgboost::Integer{static_cast<std::int32_t>(predict_type)};
|
||||
config["iteration_begin"] = xgboost::Integer{static_cast<xgboost::bst_layer_t>(iteration_begin)};
|
||||
config["iteration_end"] = xgboost::Integer{static_cast<xgboost::bst_layer_t>(iteration_end)};
|
||||
config["missing"] = xgboost::Number{static_cast<float>(missing)};
|
||||
config["strict_shape"] = xgboost::Boolean{true};
|
||||
std::string s_config;
|
||||
xgboost::Json::Dump(config, &s_config);
|
||||
|
||||
/**
|
||||
* Handle base margin
|
||||
*/
|
||||
BoosterHandle proxy{nullptr};
|
||||
|
||||
float *margin{nullptr};
|
||||
if (jmargin) {
|
||||
margin = jenv->GetFloatArrayElements(jmargin, nullptr);
|
||||
JVM_CHECK_CALL(XGProxyDMatrixCreate(&proxy));
|
||||
JVM_CHECK_CALL(
|
||||
XGDMatrixSetFloatInfo(proxy, "base_margin", margin, jenv->GetArrayLength(jmargin)));
|
||||
}
|
||||
|
||||
bst_ulong const *out_shape;
|
||||
bst_ulong out_dim;
|
||||
float const *result;
|
||||
auto ret = XGBoosterPredictFromDense(handle, s_array.c_str(), s_config.c_str(), proxy, &out_shape,
|
||||
&out_dim, &result);
|
||||
|
||||
jenv->ReleaseFloatArrayElements(jdata, data, 0);
|
||||
if (proxy) {
|
||||
XGDMatrixFree(proxy);
|
||||
jenv->ReleaseFloatArrayElements(jmargin, margin, 0);
|
||||
}
|
||||
|
||||
if (ret != 0) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::size_t n{1};
|
||||
for (std::size_t i = 0; i < out_dim; ++i) {
|
||||
n *= out_shape[i];
|
||||
}
|
||||
|
||||
jfloatArray jarray = jenv->NewFloatArray(n);
|
||||
|
||||
jenv->SetFloatArrayRegion(jarray, 0, n, result);
|
||||
jenv->SetObjectArrayElement(jout, 0, jarray);
|
||||
|
||||
API_END();
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterLoadModel
|
||||
@@ -905,33 +984,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterLoadRabitCheckpoint
|
||||
* Signature: (J[I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadRabitCheckpoint
|
||||
(JNIEnv *jenv , jclass jcls, jlong jhandle, jintArray jout) {
|
||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||
int version;
|
||||
int ret = XGBoosterLoadRabitCheckpoint(handle, &version);
|
||||
JVM_CHECK_CALL(ret);
|
||||
jint jversion = version;
|
||||
jenv->SetIntArrayRegion(jout, 0, 1, &jversion);
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterSaveRabitCheckpoint
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabitCheckpoint
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle) {
|
||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||
return XGBoosterSaveRabitCheckpoint(handle);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterGetNumFeature
|
||||
@@ -948,6 +1000,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFea
|
||||
return ret;
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoostedRound(
|
||||
JNIEnv *jenv, jclass, jlong jhandle, jintArray jout) {
|
||||
BoosterHandle handle = (BoosterHandle)jhandle;
|
||||
std::int32_t n_rounds{0};
|
||||
auto ret = XGBoosterBoostedRounds(handle, &n_rounds);
|
||||
JVM_CHECK_CALL(ret);
|
||||
jint jn_rounds = n_rounds;
|
||||
jenv->SetIntArrayRegion(jout, 0, 1, &jn_rounds);
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: CommunicatorInit
|
||||
|
||||
@@ -207,6 +207,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterEvalOneIt
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict
|
||||
(JNIEnv *, jclass, jlong, jlong, jint, jint, jobjectArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterPredictFromDense
|
||||
* Signature: (J[FJJFIII[F[[F)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredictFromDense
|
||||
(JNIEnv *, jclass, jlong, jfloatArray, jlong, jlong, jfloat, jint, jint, jint, jfloatArray, jobjectArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterLoadModel
|
||||
@@ -279,22 +287,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetAttr
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr
|
||||
(JNIEnv *, jclass, jlong, jstring, jstring);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterLoadRabitCheckpoint
|
||||
* Signature: (J[I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadRabitCheckpoint
|
||||
(JNIEnv *, jclass, jlong, jintArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterSaveRabitCheckpoint
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabitCheckpoint
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterGetNumFeature
|
||||
@@ -303,6 +295,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabit
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFeature
|
||||
(JNIEnv *, jclass, jlong, jlongArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterGetNumBoostedRound
|
||||
* Signature: (J[I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoostedRound
|
||||
(JNIEnv *, jclass, jlong, jintArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: CommunicatorInit
|
||||
@@ -359,14 +359,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorAllred
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetInfoFromInterface
|
||||
(JNIEnv *, jclass, jlong, jstring, jstring);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDeviceQuantileDMatrixCreateFromCallback
|
||||
* Signature: (Ljava/util/Iterator;FII[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDMatrixCreateFromCallback
|
||||
(JNIEnv *, jclass, jobject, jfloat, jint, jint, jlongArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGQuantileDMatrixCreateFromCallback
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -15,16 +15,23 @@
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.*;
|
||||
import java.util.*;
|
||||
|
||||
import junit.framework.TestCase;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.*;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.fail;
|
||||
|
||||
/**
|
||||
* test cases for Booster
|
||||
* test cases for Booster Inplace Predict
|
||||
*
|
||||
* @author hzx
|
||||
* @author hzx and Sovrn
|
||||
*/
|
||||
public class BoosterImplTest {
|
||||
private String train_uri = "../../demo/data/agaricus.txt.train?indexing_mode=1&format=libsvm";
|
||||
@@ -99,6 +106,179 @@ public class BoosterImplTest {
|
||||
TestCase.assertTrue(eval.eval(predicts, testMat) < 0.1f);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void inplacePredictTest() throws XGBoostError {
|
||||
/* Data Generation */
|
||||
// Generate a training set.
|
||||
int trainRows = 1000;
|
||||
int features = 10;
|
||||
int trainSize = trainRows * features;
|
||||
float[] trainX = generateRandomDataSet(trainSize);
|
||||
float[] trainY = generateRandomDataSet(trainRows);
|
||||
|
||||
DMatrix trainingMatrix = new DMatrix(trainX, trainRows, features, Float.NaN);
|
||||
trainingMatrix.setLabel(trainY);
|
||||
|
||||
// Generate a testing set
|
||||
int testRows = 10;
|
||||
int testSize = testRows * features;
|
||||
float[] testX = generateRandomDataSet(testSize);
|
||||
float[] testY = generateRandomDataSet(testRows);
|
||||
|
||||
DMatrix testingMatrix = new DMatrix(testX, testRows, features, Float.NaN);
|
||||
testingMatrix.setLabel(testY);
|
||||
|
||||
/* Training */
|
||||
|
||||
// Set parameters
|
||||
Map<String, Object> params = new HashMap<>();
|
||||
params.put("eta", 1.0);
|
||||
params.put("max_depth",2);
|
||||
params.put("silent", 1);
|
||||
params.put("tree_method", "hist");
|
||||
|
||||
Map<String, DMatrix> watches = new HashMap<>();
|
||||
watches.put("train", trainingMatrix);
|
||||
watches.put("test", testingMatrix);
|
||||
|
||||
Booster booster = XGBoost.train(trainingMatrix, params, 10, watches, null, null);
|
||||
|
||||
/* Prediction */
|
||||
|
||||
// Standard prediction
|
||||
float[][] predictions = booster.predict(testingMatrix);
|
||||
|
||||
// Inplace-prediction
|
||||
float[][] inplacePredictions = booster.inplace_predict(testX, testRows, features, Float.NaN);
|
||||
|
||||
// Confirm that the two prediction results are identical
|
||||
assertArrayEquals(predictions, inplacePredictions);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void inplacePredictMultiPredictTest() throws InterruptedException {
|
||||
// Multithreaded, multiple prediction
|
||||
int trainRows = 1000;
|
||||
int features = 10;
|
||||
int trainSize = trainRows * features;
|
||||
|
||||
int testRows = 10;
|
||||
int testSize = testRows * features;
|
||||
|
||||
//Simulate multiple predictions on multiple random data sets simultaneously.
|
||||
ExecutorService executorService = Executors.newFixedThreadPool(5);
|
||||
int predictsToPerform = 100;
|
||||
for(int i = 0; i < predictsToPerform; i++) {
|
||||
executorService.submit(() -> {
|
||||
try {
|
||||
float[] trainX = generateRandomDataSet(trainSize);
|
||||
float[] trainY = generateRandomDataSet(trainRows);
|
||||
DMatrix trainingMatrix = new DMatrix(trainX, trainRows, features, Float.NaN);
|
||||
trainingMatrix.setLabel(trainY);
|
||||
|
||||
float[] testX = generateRandomDataSet(testSize);
|
||||
float[] testY = generateRandomDataSet(testRows);
|
||||
DMatrix testingMatrix = new DMatrix(testX, testRows, features, Float.NaN);
|
||||
testingMatrix.setLabel(testY);
|
||||
|
||||
Map<String, Object> params = new HashMap<>();
|
||||
params.put("eta", 1.0);
|
||||
params.put("max_depth", 2);
|
||||
params.put("silent", 1);
|
||||
params.put("tree_method", "hist");
|
||||
|
||||
Map<String, DMatrix> watches = new HashMap<>();
|
||||
watches.put("train", trainingMatrix);
|
||||
watches.put("test", testingMatrix);
|
||||
|
||||
Booster booster = XGBoost.train(trainingMatrix, params, 10, watches, null, null);
|
||||
|
||||
float[][] predictions = booster.predict(testingMatrix);
|
||||
float[][] inplacePredictions = booster.inplace_predict(testX, testRows, features, Float.NaN);
|
||||
|
||||
assertArrayEquals(predictions, inplacePredictions);
|
||||
} catch (XGBoostError xgBoostError) {
|
||||
fail(xgBoostError.getMessage());
|
||||
}
|
||||
});
|
||||
}
|
||||
executorService.shutdown();
|
||||
if(!executorService.awaitTermination(1, TimeUnit.MINUTES))
|
||||
executorService.shutdownNow();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void inplacePredictWithMarginTest() throws XGBoostError {
|
||||
//Generate a training set
|
||||
int trainRows = 1000;
|
||||
int features = 10;
|
||||
int trainSize = trainRows * features;
|
||||
float[] trainX = generateRandomDataSet(trainSize);
|
||||
float[] trainY = generateRandomDataSet(trainRows);
|
||||
|
||||
DMatrix trainingMatrix = new DMatrix(trainX, trainRows, features, Float.NaN);
|
||||
trainingMatrix.setLabel(trainY);
|
||||
|
||||
// Generate a testing set
|
||||
int testRows = 10;
|
||||
int testSize = testRows * features;
|
||||
float[] testX = generateRandomDataSet(testSize);
|
||||
float[] testY = generateRandomDataSet(testRows);
|
||||
|
||||
DMatrix testingMatrix = new DMatrix(testX, testRows, features, Float.NaN);
|
||||
testingMatrix.setLabel(testY);
|
||||
|
||||
// Set booster parameters
|
||||
Map<String, Object> params = new HashMap<>();
|
||||
params.put("eta", 1.0);
|
||||
params.put("max_depth",2);
|
||||
params.put("tree_method", "hist");
|
||||
params.put("base_score", 0.0);
|
||||
|
||||
Map<String, DMatrix> watches = new HashMap<>();
|
||||
watches.put("train", trainingMatrix);
|
||||
watches.put("test", testingMatrix);
|
||||
|
||||
// Train booster on training matrix.
|
||||
Booster booster = XGBoost.train(trainingMatrix, params, 10, watches, null, null);
|
||||
|
||||
// Create a margin
|
||||
float[] margin = new float[testRows];
|
||||
Arrays.fill(margin, 0.5f);
|
||||
|
||||
// Define an iteration range to use all training iterations, this should match
|
||||
// the without margin call
|
||||
// which defines an iteration range of [0,0)
|
||||
int[] iterationRange = new int[] { 0, 0 };
|
||||
|
||||
float[][] inplacePredictionsWithMargin = booster.inplace_predict(testX,
|
||||
testRows,
|
||||
features,
|
||||
Float.NaN,
|
||||
iterationRange,
|
||||
Booster.PredictionType.kValue,
|
||||
margin);
|
||||
float[][] inplacePredictionsWithoutMargin = booster.inplace_predict(testX, testRows, features, Float.NaN);
|
||||
|
||||
for (int i = 0; i < inplacePredictionsWithoutMargin.length; i++) {
|
||||
for (int j = 0; j < inplacePredictionsWithoutMargin[i].length; j++) {
|
||||
inplacePredictionsWithoutMargin[i][j] += margin[j];
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < inplacePredictionsWithoutMargin.length; i++) {
|
||||
assertArrayEquals(inplacePredictionsWithMargin[i], inplacePredictionsWithoutMargin[i], 1e-6f);
|
||||
}
|
||||
}
|
||||
|
||||
private float[] generateRandomDataSet(int size) {
|
||||
float[] newSet = new float[size];
|
||||
Random random = new Random();
|
||||
for(int i = 0; i < size; i++) {
|
||||
newSet[i] = random.nextFloat();
|
||||
}
|
||||
return newSet;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void saveLoadModelWithPath() throws XGBoostError, IOException {
|
||||
DMatrix trainMat = new DMatrix(this.train_uri);
|
||||
@@ -664,14 +844,12 @@ public class BoosterImplTest {
|
||||
float tempBoosterError = eval.eval(tempBooster.predict(testMat, true, 0), testMat);
|
||||
|
||||
// Save tempBooster to bytestream and load back
|
||||
int prevVersion = tempBooster.getVersion();
|
||||
ByteArrayInputStream in = new ByteArrayInputStream(tempBooster.toByteArray());
|
||||
tempBooster = XGBoost.loadModel(in);
|
||||
in.close();
|
||||
tempBooster.setVersion(prevVersion);
|
||||
|
||||
// Continue training using tempBooster
|
||||
round = 4;
|
||||
round = 2;
|
||||
Booster booster2 = XGBoost.train(trainMat, paramMap, round, watches, null, null, null, 0, tempBooster);
|
||||
float booster2error = eval.eval(booster2.predict(testMat, true, 0), testMat);
|
||||
TestCase.assertTrue(booster1error == booster2error);
|
||||
|
||||
Reference in New Issue
Block a user