enable ROCm on latest XGBoost

This commit is contained in:
Hui Liu
2023-10-23 11:07:08 -07:00
328 changed files with 8028 additions and 3642 deletions

View File

@@ -4,16 +4,16 @@ list(APPEND JVM_SOURCES
${PROJECT_SOURCE_DIR}/jvm-packages/xgboost4j/src/native/xgboost4j.cpp
${PROJECT_SOURCE_DIR}/jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cpp)
if (USE_CUDA)
if(USE_CUDA)
list(APPEND JVM_SOURCES
${PROJECT_SOURCE_DIR}/jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cu)
endif (USE_CUDA)
endif()
add_library(xgboost4j SHARED ${JVM_SOURCES} ${XGBOOST_OBJ_SOURCES})
if (ENABLE_ALL_WARNINGS)
if(ENABLE_ALL_WARNINGS)
target_compile_options(xgboost4j PUBLIC -Wall -Wextra)
endif (ENABLE_ALL_WARNINGS)
endif()
target_link_libraries(xgboost4j PRIVATE objxgboost)
target_include_directories(xgboost4j

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -32,57 +32,53 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite
}
private def createNewModels():
(String, XGBoostClassificationModel, XGBoostClassificationModel) = {
(String, XGBoostClassificationModel, XGBoostClassificationModel) = {
val tmpPath = createTmpFolder("test").toAbsolutePath.toString
val (model4, model8) = {
val (model2, model4) = {
val training = buildDataFrame(Classification.train)
val paramMap = produceParamMap(tmpPath, 2)
(new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
}
(tmpPath, model4, model8)
(tmpPath, model2, model4)
}
test("test update/load models") {
val (tmpPath, model4, model8) = createNewModels()
val (tmpPath, model2, model4) = createNewModels()
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
manager.updateCheckpoint(model4._booster.booster)
manager.updateCheckpoint(model2._booster.booster)
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "4.model")
assert(manager.loadCheckpointAsScalaBooster().getVersion == 4)
assert(files.head.getPath.getName == "1.model")
assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 2)
manager.updateCheckpoint(model8._booster)
manager.updateCheckpoint(model4._booster)
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "8.model")
assert(manager.loadCheckpointAsScalaBooster().getVersion == 8)
assert(files.head.getPath.getName == "3.model")
assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 4)
}
test("test cleanUpHigherVersions") {
val (tmpPath, model4, model8) = createNewModels()
val (tmpPath, model2, model4) = createNewModels()
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
manager.updateCheckpoint(model8._booster)
manager.cleanUpHigherVersions(8)
assert(new File(s"$tmpPath/8.model").exists())
manager.updateCheckpoint(model4._booster)
manager.cleanUpHigherVersions(3)
assert(new File(s"$tmpPath/3.model").exists())
manager.cleanUpHigherVersions(4)
assert(!new File(s"$tmpPath/8.model").exists())
manager.cleanUpHigherVersions(2)
assert(!new File(s"$tmpPath/3.model").exists())
}
test("test checkpoint rounds") {
import scala.collection.JavaConverters._
val (tmpPath, model4, model8) = createNewModels()
val (tmpPath, model2, model4) = createNewModels()
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
assertResult(Seq(7))(
manager.getCheckpointRounds(0, 7).asScala)
assertResult(Seq(2, 4, 6, 7))(
manager.getCheckpointRounds(2, 7).asScala)
manager.updateCheckpoint(model4._booster)
assertResult(Seq(4, 6, 7))(
manager.getCheckpointRounds(2, 7).asScala)
assertResult(Seq(2))(manager.getCheckpointRounds(0, 0, 3).asScala)
assertResult(Seq(0, 2, 4, 6))(manager.getCheckpointRounds(0, 2, 7).asScala)
assertResult(Seq(0, 2, 4, 6, 7))(manager.getCheckpointRounds(0, 2, 8).asScala)
}
@@ -109,8 +105,8 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite
// Check only one model is kept after training
val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "8.model")
val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
assert(files.head.getPath.getName == "4.model")
val tmpModel = SXGBoost.loadModel(s"$tmpPath/4.model")
// Train next model based on prev model
val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
assert(error(tmpModel) >= error(prevModel._booster))

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -39,6 +39,21 @@ public class Booster implements Serializable, KryoSerializable {
// handle to the booster.
private long handle = 0;
private int version = 0;
/**
 * Type of prediction, used for inplace_predict.
 */
public enum PredictionType {
  // Final predicted value.
  kValue(0),
  // Raw (untransformed) margin score.
  kMargin(1);

  // Integer code passed through to the native predict call.
  private Integer ptype;

  private PredictionType(final Integer ptype) {
    this.ptype = ptype;
  }

  /** @return the integer code understood by the native XGBoost predict API. */
  public Integer getPType() {
    return ptype;
  }
}
/**
* Create a new Booster with empty stage.
@@ -375,6 +390,97 @@ public class Booster implements Serializable, KryoSerializable {
return predicts;
}
/**
 * Perform thread-safe prediction, using all trained iterations and no base margin.
 *
 * @param data Flattened input matrix of features for prediction
 * @param nrow The number of predictions to make (count of input matrix rows)
 * @param ncol The number of features in the model (count of input matrix columns)
 * @param missing Value indicating missing element in the <code>data</code> input matrix
 *
 * @return predict Result matrix
 */
public float[][] inplace_predict(float[] data,
                                 int nrow,
                                 int ncol,
                                 float missing) throws XGBoostError {
  // The half-open range [0, 0) is the native API's convention for
  // "use all boosting rounds".
  int[] iteration_range = new int[2];
  iteration_range[0] = 0;
  iteration_range[1] = 0;
  // Delegate to the full overload with value prediction and no base margin.
  return this.inplace_predict(data, nrow, ncol,
                              missing, iteration_range, PredictionType.kValue, null);
}
/**
 * Perform thread-safe prediction over a chosen range of boosting iterations.
 *
 * @param data Flattened input matrix of features for prediction
 * @param nrow The number of predictions to make (count of input matrix rows)
 * @param ncol The number of features in the model (count of input matrix columns)
 * @param missing Value indicating missing element in the <code>data</code> input matrix
 * @param iteration_range Specifies which layer of trees are used in prediction. For
 *                        example, if a random forest is trained with 100 rounds.
 *                        Specifying `iteration_range=[10, 20)`, then only the forests
 *                        built during [10, 20) (half open set) rounds are used in this
 *                        prediction.
 *
 * @return predict Result matrix
 */
public float[][] inplace_predict(float[] data,
                                 int nrow,
                                 int ncol,
                                 float missing, int[] iteration_range) throws XGBoostError {
  // Delegate to the full overload with value prediction and no base margin.
  return this.inplace_predict(data, nrow, ncol,
                              missing, iteration_range, PredictionType.kValue, null);
}
/**
 * Perform thread-safe prediction.
 *
 * @param data Flattened input matrix of features for prediction
 * @param nrow The number of predictions to make (count of input matrix rows)
 * @param ncol The number of features in the model (count of input matrix columns)
 * @param missing Value indicating missing element in the <code>data</code> input matrix
 * @param iteration_range Specifies which layer of trees are used in prediction. For
 *                        example, if a random forest is trained with 100 rounds.
 *                        Specifying `iteration_range=[10, 20)`, then only the forests
 *                        built during [10, 20) (half open set) rounds are used in this
 *                        prediction.
 * @param predict_type What kind of prediction to run.
 * @param base_margin Optional base margin added to the prediction; may be null.
 * @return predict Result matrix
 */
public float[][] inplace_predict(float[] data,
                                 int nrow,
                                 int ncol,
                                 float missing,
                                 int[] iteration_range,
                                 PredictionType predict_type,
                                 float[] base_margin) throws XGBoostError {
  if (iteration_range.length != 2) {
    throw new XGBoostError(new String("Iteration range is expected to be [begin, end)."));
  }
  int ptype = predict_type.getPType();

  int begin = iteration_range[0];
  int end = iteration_range[1];

  // The native call returns a single flattened row-major array.
  float[][] rawPredicts = new float[1][];
  XGBoostJNI.checkCall(XGBoostJNI.XGBoosterPredictFromDense(handle, data, nrow, ncol,
                                                            missing,
                                                            begin, end, ptype, base_margin, rawPredicts));

  // Reshape the flat native output into an nrow x col matrix.
  int col = rawPredicts[0].length / nrow;
  float[][] predicts = new float[nrow][col];
  int r, c;
  for (int i = 0; i < rawPredicts[0].length; i++) {
    r = i / col;
    c = i % col;
    predicts[r][c] = rawPredicts[0][i];
  }
  return predicts;
}
/**
* Predict leaf indices given the data
*
@@ -681,35 +787,6 @@ public class Booster implements Serializable, KryoSerializable {
return importanceMap;
}
/**
* Save the model as byte array representation.
* Write these bytes to a file will give compatible format with other xgboost bindings.
*
* If java natively support HDFS file API, use toByteArray and write the ByteArray
*
* @param withStats Controls whether the split statistics are output.
* @return dumped model information
* @throws XGBoostError native error
*/
/**
 * Dump the model as an array of text strings, one entry per tree.
 *
 * @param withStats whether split statistics are included in the dump
 * @return dumped model information
 * @throws XGBoostError native error
 */
private String[] getDumpInfo(boolean withStats) throws XGBoostError {
  // The native API takes an int flag rather than a boolean.
  int statsFlag = 0;
  if (withStats) {
    statsFlag = 1;
  }
  String[][] modelInfos = new String[1][];
  // "" = no feature-map file; "text" selects the plain-text dump format.
  XGBoostJNI.checkCall(XGBoostJNI.XGBoosterDumpModelEx(handle, "", statsFlag, "text",
                                                       modelInfos));
  return modelInfos[0];
}
public int getVersion() {
return this.version;
}
public void setVersion(int version) {
this.version = version;
}
/**
* Save model into raw byte array. Currently it's using the deprecated format as
* default, which will be changed into `ubj` in future releases.
@@ -735,29 +812,6 @@ public class Booster implements Serializable, KryoSerializable {
return bytes[0];
}
/**
* Load the booster model from thread-local rabit checkpoint.
* This is only used in distributed training.
* @return the stored version number of the checkpoint.
* @throws XGBoostError
*/
int loadRabitCheckpoint() throws XGBoostError {
int[] out = new int[1];
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
version = out[0];
return version;
}
/**
* Save the booster model into thread-local rabit checkpoint and increment the version.
* This is only used in distributed training.
* @throws XGBoostError
*/
void saveRabitCheckpoint() throws XGBoostError {
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
version += 1;
}
/**
* Get number of model features.
* @return the number of features.
@@ -768,6 +822,11 @@ public class Booster implements Serializable, KryoSerializable {
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetNumFeature(this.handle, numFeature));
return numFeature[0];
}
/**
 * Get the number of boosting rounds completed by this model.
 *
 * @return the number of boosted rounds
 * @throws XGBoostError native error
 */
public int getNumBoostedRound() throws XGBoostError {
  int[] numRound = new int[1];
  XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetNumBoostedRound(this.handle, numRound));
  return numRound[0];
}
/**
* Internal initialization function.

View File

@@ -1,3 +1,18 @@
/*
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java;
import java.io.IOException;
@@ -15,7 +30,7 @@ public class ExternalCheckpointManager {
private Log logger = LogFactory.getLog("ExternalCheckpointManager");
private String modelSuffix = ".model";
private Path checkpointPath;
private Path checkpointPath; // directory for checkpoints
private FileSystem fs;
public ExternalCheckpointManager(String checkpointPath, FileSystem fs) throws XGBoostError {
@@ -35,6 +50,7 @@ public class ExternalCheckpointManager {
if (!fs.exists(checkpointPath)) {
return new ArrayList<>();
} else {
// Get integer versions from a list of checkpoint files.
return Arrays.stream(fs.listStatus(checkpointPath))
.map(path -> path.getPath().getName())
.filter(fileName -> fileName.endsWith(modelSuffix))
@@ -44,6 +60,11 @@ public class ExternalCheckpointManager {
}
}
// Returns the largest version number in the list. Assumes `versions` is
// non-empty (callers guard with a size check); an empty list would make
// Optional.get() throw NoSuchElementException.
private Integer latest(List<Integer> versions) {
  return versions.stream()
      .max(Comparator.comparing(Integer::valueOf)).get();
}
public void cleanPath() throws IOException {
fs.delete(checkpointPath, true);
}
@@ -51,12 +72,11 @@ public class ExternalCheckpointManager {
public Booster loadCheckpointAsBooster() throws IOException, XGBoostError {
List<Integer> versions = getExistingVersions();
if (versions.size() > 0) {
int latestVersion = versions.stream().max(Comparator.comparing(Integer::valueOf)).get();
int latestVersion = this.latest(versions);
String checkpointPath = getPath(latestVersion);
InputStream in = fs.open(new Path(checkpointPath));
logger.info("loaded checkpoint from " + checkpointPath);
Booster booster = XGBoost.loadModel(in);
booster.setVersion(latestVersion);
return booster;
} else {
return null;
@@ -65,13 +85,16 @@ public class ExternalCheckpointManager {
public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XGBoostError {
List<String> prevModelPaths = getExistingVersions().stream()
.map(this::getPath).collect(Collectors.toList());
String eventualPath = getPath(boosterToCheckpoint.getVersion());
.map(this::getPath).collect(Collectors.toList());
// checkpointing is done after update, so n_rounds - 1 is the current iteration
// accounting for training continuation.
Integer iter = boosterToCheckpoint.getNumBoostedRound() - 1;
String eventualPath = getPath(iter);
String tempPath = eventualPath + "-" + UUID.randomUUID();
try (OutputStream out = fs.create(new Path(tempPath), true)) {
boosterToCheckpoint.saveModel(out);
fs.rename(new Path(tempPath), new Path(eventualPath));
logger.info("saving checkpoint with version " + boosterToCheckpoint.getVersion());
logger.info("saving checkpoint with version " + iter);
prevModelPaths.stream().forEach(path -> {
try {
fs.delete(new Path(path), true);
@@ -83,7 +106,7 @@ public class ExternalCheckpointManager {
}
public void cleanUpHigherVersions(int currentRound) throws IOException {
getExistingVersions().stream().filter(v -> v / 2 >= currentRound).forEach(v -> {
getExistingVersions().stream().filter(v -> v > currentRound).forEach(v -> {
try {
fs.delete(new Path(getPath(v)), true);
} catch (IOException e) {
@@ -91,27 +114,26 @@ public class ExternalCheckpointManager {
}
});
}
public List<Integer> getCheckpointRounds(int checkpointInterval, int numOfRounds)
// Get a list of iterations that need checkpointing.
public List<Integer> getCheckpointRounds(
int firstRound, int checkpointInterval, int numOfRounds)
throws IOException {
int end = firstRound + numOfRounds; // exclusive
int lastRound = end - 1;
if (end - 1 < 0) {
throw new IllegalArgumentException("Invalid `numOfRounds`.");
}
List<Integer> arr = new ArrayList<>();
if (checkpointInterval > 0) {
List<Integer> prevRounds =
getExistingVersions().stream().map(v -> v / 2).collect(Collectors.toList());
prevRounds.add(0);
int firstCheckpointRound = prevRounds.stream()
.max(Comparator.comparing(Integer::valueOf)).get() + checkpointInterval;
List<Integer> arr = new ArrayList<>();
for (int i = firstCheckpointRound; i <= numOfRounds; i += checkpointInterval) {
for (int i = firstRound; i < end; i += checkpointInterval) {
arr.add(i);
}
arr.add(numOfRounds);
return arr;
} else if (checkpointInterval <= 0) {
List<Integer> l = new ArrayList<Integer>();
l.add(numOfRounds);
return l;
} else {
throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set.");
}
if (!arr.contains(lastRound)) {
arr.add(lastRound);
}
return arr;
}
}

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014,2021 by Contributors
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -133,7 +133,7 @@ public class XGBoost {
int earlyStoppingRound) throws XGBoostError {
return train(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, null);
}
// save checkpoint if iter is in checkpointIterations
private static void saveCheckpoint(
Booster booster,
int iter,
@@ -169,7 +169,6 @@ public class XGBoost {
int bestIteration;
List<String> names = new ArrayList<String>();
List<DMatrix> mats = new ArrayList<DMatrix>();
Set<Integer> checkpointIterations = new HashSet<>();
ExternalCheckpointManager ecm = null;
if (checkpointPath != null) {
ecm = new ExternalCheckpointManager(checkpointPath, fs);
@@ -203,32 +202,30 @@ public class XGBoost {
booster = new Booster(params, allMats);
booster.setFeatureNames(dtrain.getFeatureNames());
booster.setFeatureTypes(dtrain.getFeatureTypes());
booster.loadRabitCheckpoint();
} else {
// Start training on an existing booster
booster.setParams(params);
}
Set<Integer> checkpointIterations = new HashSet<>();
if (ecm != null) {
checkpointIterations = new HashSet<>(ecm.getCheckpointRounds(checkpointInterval, numRounds));
checkpointIterations = new HashSet<>(
ecm.getCheckpointRounds(booster.getNumBoostedRound(), checkpointInterval, numRounds));
}
boolean initial_best_score_flag = false;
boolean max_direction = false;
// begin to train
for (int iter = booster.getVersion() / 2; iter < numRounds; iter++) {
if (booster.getVersion() % 2 == 0) {
if (obj != null) {
booster.update(dtrain, obj);
} else {
booster.update(dtrain, iter);
}
saveCheckpoint(booster, iter, checkpointIterations, ecm);
booster.saveRabitCheckpoint();
for (int iter = 0; iter < numRounds; iter++) {
if (obj != null) {
booster.update(dtrain, iter, obj);
} else {
booster.update(dtrain, iter);
}
saveCheckpoint(booster, iter, checkpointIterations, ecm);
//evaluation
// evaluation
if (evalMats.length > 0) {
float[] metricsOut = new float[evalMats.length];
String evalInfo;
@@ -285,7 +282,6 @@ public class XGBoost {
Communicator.communicatorPrint(evalInfo + '\n');
}
}
booster.saveRabitCheckpoint();
}
return booster;
}

View File

@@ -119,6 +119,10 @@ class XGBoostJNI {
public final static native int XGBoosterPredict(long handle, long dmat, int option_mask,
int ntree_limit, float[][] predicts);
public final static native int XGBoosterPredictFromDense(long handle, float[] data,
long nrow, long ncol, float missing, int iteration_begin, int iteration_end, int predict_type, float[] margin,
float[][] predicts);
public final static native int XGBoosterLoadModel(long handle, String fname);
public final static native int XGBoosterSaveModel(long handle, String fname);
@@ -136,10 +140,11 @@ class XGBoostJNI {
public final static native int XGBoosterGetAttrNames(long handle, String[][] out_strings);
public final static native int XGBoosterGetAttr(long handle, String key, String[] out_string);
public final static native int XGBoosterSetAttr(long handle, String key, String value);
public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version);
public final static native int XGBoosterSaveRabitCheckpoint(long handle);
public final static native int XGBoosterGetNumFeature(long handle, long[] feature);
public final static native int XGBoosterGetNumBoostedRound(long handle, int[] rounds);
// communicator functions
public final static native int CommunicatorInit(String[] args);
public final static native int CommunicatorFinalize();
@@ -154,10 +159,6 @@ class XGBoostJNI {
public final static native int XGDMatrixSetInfoFromInterface(
long handle, String field, String json);
@Deprecated
public final static native int XGDeviceQuantileDMatrixCreateFromCallback(
java.util.Iterator<ColumnBatch> iter, float missing, int nthread, int maxBin, long[] out);
public final static native int XGQuantileDMatrixCreateFromCallback(
java.util.Iterator<ColumnBatch> iter, java.util.Iterator<ColumnBatch> ref, String config, long[] out);

View File

@@ -326,7 +326,7 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
@throws(classOf[XGBoostError])
def getNumFeature: Long = booster.getNumFeature
def getVersion: Int = booster.getVersion
def getNumBoostedRound: Long = booster.getNumBoostedRound
/**
* Save model into a raw byte array. Available options are "json", "ubj" and "deprecated".

View File

@@ -684,6 +684,85 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict
return ret;
}
/*
 * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method:    XGBoosterPredictFromDense
 * Signature: (J[FJJFIII[F[[F)I
 *
 * Thread-safe in-place prediction over a dense, row-major float matrix coming
 * from Java. The result is copied into jout[0] as a flat float array.
 */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredictFromDense(
    JNIEnv *jenv, jclass jcls, jlong jhandle, jfloatArray jdata, jlong num_rows, jlong num_features,
    jfloat missing, jint iteration_begin, jint iteration_end, jint predict_type,
    jfloatArray jmargin, jobjectArray jout) {
  API_BEGIN();
  BoosterHandle handle = reinterpret_cast<BoosterHandle>(jhandle);

  /**
   * Create array interface.
   */
  namespace linalg = xgboost::linalg;
  // Borrow (possibly copy) the Java float buffer; released after the native call.
  jfloat *data = jenv->GetFloatArrayElements(jdata, nullptr);
  xgboost::Context ctx;
  // Describe the borrowed buffer to XGBoost via an array-interface JSON string.
  auto t_data = linalg::MakeTensorView(
      ctx.Device(),
      xgboost::common::Span{data, static_cast<std::size_t>(num_rows * num_features)}, num_rows,
      num_features);
  auto s_array = linalg::ArrayInterfaceStr(t_data);

  /**
   * Create configuration object.
   */
  xgboost::Json config{xgboost::Object{}};
  // NOTE(review): cache_id is left as a default Integer — presumably this
  // disables prediction-cache reuse; confirm against the native API docs.
  config["cache_id"] = xgboost::Integer{};
  config["type"] = xgboost::Integer{static_cast<std::int32_t>(predict_type)};
  config["iteration_begin"] = xgboost::Integer{static_cast<xgboost::bst_layer_t>(iteration_begin)};
  config["iteration_end"] = xgboost::Integer{static_cast<xgboost::bst_layer_t>(iteration_end)};
  config["missing"] = xgboost::Number{static_cast<float>(missing)};
  config["strict_shape"] = xgboost::Boolean{true};
  std::string s_config;
  xgboost::Json::Dump(config, &s_config);

  /**
   * Handle base margin
   */
  // The optional base margin is delivered through a proxy DMatrix's meta info.
  BoosterHandle proxy{nullptr};
  float *margin{nullptr};
  if (jmargin) {
    margin = jenv->GetFloatArrayElements(jmargin, nullptr);
    JVM_CHECK_CALL(XGProxyDMatrixCreate(&proxy));
    JVM_CHECK_CALL(
        XGDMatrixSetFloatInfo(proxy, "base_margin", margin, jenv->GetArrayLength(jmargin)));
  }

  bst_ulong const *out_shape;
  bst_ulong out_dim;
  float const *result;
  auto ret = XGBoosterPredictFromDense(handle, s_array.c_str(), s_config.c_str(), proxy, &out_shape,
                                       &out_dim, &result);

  // Release borrowed JNI buffers before any early return.
  jenv->ReleaseFloatArrayElements(jdata, data, 0);
  if (proxy) {
    XGDMatrixFree(proxy);
    jenv->ReleaseFloatArrayElements(jmargin, margin, 0);
  }

  if (ret != 0) {
    return ret;
  }

  // Flatten the native output shape and copy the result into a Java array.
  std::size_t n{1};
  for (std::size_t i = 0; i < out_dim; ++i) {
    n *= out_shape[i];
  }

  jfloatArray jarray = jenv->NewFloatArray(n);

  jenv->SetFloatArrayRegion(jarray, 0, n, result);
  jenv->SetObjectArrayElement(jout, 0, jarray);
  API_END();
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterLoadModel
@@ -905,33 +984,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr
return ret;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterLoadRabitCheckpoint
* Signature: (J[I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadRabitCheckpoint
(JNIEnv *jenv , jclass jcls, jlong jhandle, jintArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle;
int version;
int ret = XGBoosterLoadRabitCheckpoint(handle, &version);
JVM_CHECK_CALL(ret);
jint jversion = version;
jenv->SetIntArrayRegion(jout, 0, 1, &jversion);
return ret;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterSaveRabitCheckpoint
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabitCheckpoint
(JNIEnv *jenv, jclass jcls, jlong jhandle) {
BoosterHandle handle = (BoosterHandle) jhandle;
return XGBoosterSaveRabitCheckpoint(handle);
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterGetNumFeature
@@ -948,6 +1000,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFea
return ret;
}
/*
 * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method:    XGBoosterGetNumBoostedRound
 * Signature: (J[I)I
 *
 * Writes the booster's completed round count into jout[0].
 */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoostedRound(
    JNIEnv *jenv, jclass, jlong jhandle, jintArray jout) {
  BoosterHandle handle = (BoosterHandle)jhandle;
  std::int32_t n_rounds{0};
  auto ret = XGBoosterBoostedRounds(handle, &n_rounds);
  JVM_CHECK_CALL(ret);
  jint jn_rounds = n_rounds;
  // Copy the scalar result into the caller-supplied single-element array.
  jenv->SetIntArrayRegion(jout, 0, 1, &jn_rounds);
  return ret;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: CommunicatorInit

View File

@@ -207,6 +207,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterEvalOneIt
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict
(JNIEnv *, jclass, jlong, jlong, jint, jint, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterPredictFromDense
* Signature: (J[FJJFIII[F[[F)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredictFromDense
(JNIEnv *, jclass, jlong, jfloatArray, jlong, jlong, jfloat, jint, jint, jint, jfloatArray, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterLoadModel
@@ -279,22 +287,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetAttr
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr
(JNIEnv *, jclass, jlong, jstring, jstring);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterLoadRabitCheckpoint
* Signature: (J[I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadRabitCheckpoint
(JNIEnv *, jclass, jlong, jintArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterSaveRabitCheckpoint
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabitCheckpoint
(JNIEnv *, jclass, jlong);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterGetNumFeature
@@ -303,6 +295,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabit
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFeature
(JNIEnv *, jclass, jlong, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterGetNumBoostedRound
* Signature: (J[I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoostedRound
(JNIEnv *, jclass, jlong, jintArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: CommunicatorInit
@@ -359,14 +359,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorAllred
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetInfoFromInterface
(JNIEnv *, jclass, jlong, jstring, jstring);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDeviceQuantileDMatrixCreateFromCallback
* Signature: (Ljava/util/Iterator;FII[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDMatrixCreateFromCallback
(JNIEnv *, jclass, jobject, jfloat, jint, jint, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGQuantileDMatrixCreateFromCallback

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -15,16 +15,23 @@
*/
package ml.dmlc.xgboost4j.java;
import java.io.*;
import java.util.*;
import junit.framework.TestCase;
import org.junit.Test;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.*;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.fail;
/**
* test cases for Booster
* test cases for Booster Inplace Predict
*
* @author hzx
* @author hzx and Sovrn
*/
public class BoosterImplTest {
private String train_uri = "../../demo/data/agaricus.txt.train?indexing_mode=1&format=libsvm";
@@ -99,6 +106,179 @@ public class BoosterImplTest {
TestCase.assertTrue(eval.eval(predicts, testMat) < 0.1f);
}
@Test
public void inplacePredictTest() throws XGBoostError {
  /* Data Generation */
  // Generate a training set.
  int trainRows = 1000;
  int features = 10;
  int trainSize = trainRows * features;
  float[] trainX = generateRandomDataSet(trainSize);
  float[] trainY = generateRandomDataSet(trainRows);
  DMatrix trainingMatrix = new DMatrix(trainX, trainRows, features, Float.NaN);
  trainingMatrix.setLabel(trainY);

  // Generate a testing set
  int testRows = 10;
  int testSize = testRows * features;
  float[] testX = generateRandomDataSet(testSize);
  float[] testY = generateRandomDataSet(testRows);
  DMatrix testingMatrix = new DMatrix(testX, testRows, features, Float.NaN);
  testingMatrix.setLabel(testY);

  /* Training */
  // Set parameters
  Map<String, Object> params = new HashMap<>();
  params.put("eta", 1.0);
  params.put("max_depth",2);
  params.put("silent", 1);
  params.put("tree_method", "hist");

  Map<String, DMatrix> watches = new HashMap<>();
  watches.put("train", trainingMatrix);
  watches.put("test", testingMatrix);

  Booster booster = XGBoost.train(trainingMatrix, params, 10, watches, null, null);

  /* Prediction */
  // Standard prediction through a DMatrix.
  float[][] predictions = booster.predict(testingMatrix);
  // Inplace-prediction over the same flattened feature array.
  float[][] inplacePredictions = booster.inplace_predict(testX, testRows, features, Float.NaN);

  // Confirm that the two prediction results are identical
  assertArrayEquals(predictions, inplacePredictions);
}
@Test
public void inplacePredictMultiPredictTest() throws InterruptedException {
  // Multithreaded, multiple prediction
  int trainRows = 1000;
  int features = 10;
  int trainSize = trainRows * features;
  int testRows = 10;
  int testSize = testRows * features;

  //Simulate multiple predictions on multiple random data sets simultaneously.
  ExecutorService executorService = Executors.newFixedThreadPool(5);
  int predictsToPerform = 100;
  // NOTE(review): assertion failures raised inside the submitted tasks are
  // captured by the executor's Futures and never re-checked here, so a failing
  // worker cannot fail this test — consider collecting the Futures and calling
  // get() on each. TODO confirm intended behavior.
  for(int i = 0; i < predictsToPerform; i++) {
    executorService.submit(() -> {
      try {
        // Each task trains its own booster on fresh random data, then compares
        // DMatrix-based prediction with inplace prediction.
        float[] trainX = generateRandomDataSet(trainSize);
        float[] trainY = generateRandomDataSet(trainRows);
        DMatrix trainingMatrix = new DMatrix(trainX, trainRows, features, Float.NaN);
        trainingMatrix.setLabel(trainY);

        float[] testX = generateRandomDataSet(testSize);
        float[] testY = generateRandomDataSet(testRows);
        DMatrix testingMatrix = new DMatrix(testX, testRows, features, Float.NaN);
        testingMatrix.setLabel(testY);

        Map<String, Object> params = new HashMap<>();
        params.put("eta", 1.0);
        params.put("max_depth", 2);
        params.put("silent", 1);
        params.put("tree_method", "hist");

        Map<String, DMatrix> watches = new HashMap<>();
        watches.put("train", trainingMatrix);
        watches.put("test", testingMatrix);

        Booster booster = XGBoost.train(trainingMatrix, params, 10, watches, null, null);

        float[][] predictions = booster.predict(testingMatrix);
        float[][] inplacePredictions = booster.inplace_predict(testX, testRows, features, Float.NaN);
        assertArrayEquals(predictions, inplacePredictions);
      } catch (XGBoostError xgBoostError) {
        fail(xgBoostError.getMessage());
      }
    });
  }
  // Give all tasks up to a minute, then force shutdown.
  executorService.shutdown();
  if(!executorService.awaitTermination(1, TimeUnit.MINUTES))
    executorService.shutdownNow();
}
@Test
public void inplacePredictWithMarginTest() throws XGBoostError {
  //Generate a training set
  int trainRows = 1000;
  int features = 10;
  int trainSize = trainRows * features;
  float[] trainX = generateRandomDataSet(trainSize);
  float[] trainY = generateRandomDataSet(trainRows);
  DMatrix trainingMatrix = new DMatrix(trainX, trainRows, features, Float.NaN);
  trainingMatrix.setLabel(trainY);

  // Generate a testing set
  int testRows = 10;
  int testSize = testRows * features;
  float[] testX = generateRandomDataSet(testSize);
  float[] testY = generateRandomDataSet(testRows);
  DMatrix testingMatrix = new DMatrix(testX, testRows, features, Float.NaN);
  testingMatrix.setLabel(testY);

  // Set booster parameters
  // base_score 0.0 so the margin contribution is purely the supplied array.
  Map<String, Object> params = new HashMap<>();
  params.put("eta", 1.0);
  params.put("max_depth",2);
  params.put("tree_method", "hist");
  params.put("base_score", 0.0);

  Map<String, DMatrix> watches = new HashMap<>();
  watches.put("train", trainingMatrix);
  watches.put("test", testingMatrix);

  // Train booster on training matrix.
  Booster booster = XGBoost.train(trainingMatrix, params, 10, watches, null, null);

  // Create a margin
  float[] margin = new float[testRows];
  Arrays.fill(margin, 0.5f);

  // Define an iteration range to use all training iterations, this should match
  // the without margin call
  // which defines an iteration range of [0,0)
  int[] iterationRange = new int[] { 0, 0 };

  float[][] inplacePredictionsWithMargin = booster.inplace_predict(testX,
                                                                   testRows,
                                                                   features,
                                                                   Float.NaN,
                                                                   iterationRange,
                                                                   Booster.PredictionType.kValue,
                                                                   margin);
  float[][] inplacePredictionsWithoutMargin = booster.inplace_predict(testX, testRows, features, Float.NaN);
  // NOTE(review): `margin` holds one value per test ROW but is indexed by the
  // COLUMN index j here. With a single-output model j is always 0 and every
  // margin entry is 0.5f, so this works — verify before enabling multi-output.
  for (int i = 0; i < inplacePredictionsWithoutMargin.length; i++) {
    for (int j = 0; j < inplacePredictionsWithoutMargin[i].length; j++) {
      inplacePredictionsWithoutMargin[i][j] += margin[j];
    }
  }
  // Manually-shifted predictions must match the natively-margined ones.
  for (int i = 0; i < inplacePredictionsWithoutMargin.length; i++) {
    assertArrayEquals(inplacePredictionsWithMargin[i], inplacePredictionsWithoutMargin[i], 1e-6f);
  }
}
// Returns an array of `size` uniformly random floats in [0, 1).
private float[] generateRandomDataSet(int size) {
  float[] newSet = new float[size];
  Random random = new Random();
  for(int i = 0; i < size; i++) {
    newSet[i] = random.nextFloat();
  }
  return newSet;
}
@Test
public void saveLoadModelWithPath() throws XGBoostError, IOException {
DMatrix trainMat = new DMatrix(this.train_uri);
@@ -664,14 +844,12 @@ public class BoosterImplTest {
float tempBoosterError = eval.eval(tempBooster.predict(testMat, true, 0), testMat);
// Save tempBooster to bytestream and load back
int prevVersion = tempBooster.getVersion();
ByteArrayInputStream in = new ByteArrayInputStream(tempBooster.toByteArray());
tempBooster = XGBoost.loadModel(in);
in.close();
tempBooster.setVersion(prevVersion);
// Continue training using tempBooster
round = 4;
round = 2;
Booster booster2 = XGBoost.train(trainMat, paramMap, round, watches, null, null, null, 0, tempBooster);
float booster2error = eval.eval(booster2.predict(testMat, true, 0), testMat);
TestCase.assertTrue(booster1error == booster2error);