Allow JVM-Package to access inplace predict method (#9167)
--------- Co-authored-by: Stephan T. Lavavej <stl@nuwen.net> Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com> Co-authored-by: Joe <25804777+ByteSizedJoe@users.noreply.github.com>
This commit is contained in:
parent
9027686cac
commit
d05ea589fb
@ -39,6 +39,21 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
// handle to the booster.
|
// handle to the booster.
|
||||||
private long handle = 0;
|
private long handle = 0;
|
||||||
private int version = 0;
|
private int version = 0;
|
||||||
|
/**
|
||||||
|
* Type of prediction, used for inplace_predict.
|
||||||
|
*/
|
||||||
|
public enum PredictionType {
|
||||||
|
kValue(0),
|
||||||
|
kMargin(1);
|
||||||
|
|
||||||
|
private Integer ptype;
|
||||||
|
private PredictionType(final Integer ptype) {
|
||||||
|
this.ptype = ptype;
|
||||||
|
}
|
||||||
|
public Integer getPType() {
|
||||||
|
return ptype;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a new Booster with empty stage.
|
* Create a new Booster with empty stage.
|
||||||
@ -375,6 +390,97 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
return predicts;
|
return predicts;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Perform thread-safe prediction.
|
||||||
|
*
|
||||||
|
* @param data Flattened input matrix of features for prediction
|
||||||
|
* @param nrow The number of preditions to make (count of input matrix rows)
|
||||||
|
* @param ncol The number of features in the model (count of input matrix columns)
|
||||||
|
* @param missing Value indicating missing element in the <code>data</code> input matrix
|
||||||
|
*
|
||||||
|
* @return predict Result matrix
|
||||||
|
*/
|
||||||
|
public float[][] inplace_predict(float[] data,
|
||||||
|
int nrow,
|
||||||
|
int ncol,
|
||||||
|
float missing) throws XGBoostError {
|
||||||
|
int[] iteration_range = new int[2];
|
||||||
|
iteration_range[0] = 0;
|
||||||
|
iteration_range[1] = 0;
|
||||||
|
return this.inplace_predict(data, nrow, ncol,
|
||||||
|
missing, iteration_range, PredictionType.kValue, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Perform thread-safe prediction.
|
||||||
|
*
|
||||||
|
* @param data Flattened input matrix of features for prediction
|
||||||
|
* @param nrow The number of preditions to make (count of input matrix rows)
|
||||||
|
* @param ncol The number of features in the model (count of input matrix columns)
|
||||||
|
* @param missing Value indicating missing element in the <code>data</code> input matrix
|
||||||
|
* @param iteration_range Specifies which layer of trees are used in prediction. For
|
||||||
|
* example, if a random forest is trained with 100 rounds.
|
||||||
|
* Specifying `iteration_range=[10, 20)`, then only the forests
|
||||||
|
* built during [10, 20) (half open set) rounds are used in this
|
||||||
|
* prediction.
|
||||||
|
*
|
||||||
|
* @return predict Result matrix
|
||||||
|
*/
|
||||||
|
public float[][] inplace_predict(float[] data,
|
||||||
|
int nrow,
|
||||||
|
int ncol,
|
||||||
|
float missing, int[] iteration_range) throws XGBoostError {
|
||||||
|
return this.inplace_predict(data, nrow, ncol,
|
||||||
|
missing, iteration_range, PredictionType.kValue, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Perform thread-safe prediction.
|
||||||
|
*
|
||||||
|
* @param data Flattened input matrix of features for prediction
|
||||||
|
* @param nrow The number of preditions to make (count of input matrix rows)
|
||||||
|
* @param ncol The number of features in the model (count of input matrix columns)
|
||||||
|
* @param missing Value indicating missing element in the <code>data</code> input matrix
|
||||||
|
* @param iteration_range Specifies which layer of trees are used in prediction. For
|
||||||
|
* example, if a random forest is trained with 100 rounds.
|
||||||
|
* Specifying `iteration_range=[10, 20)`, then only the forests
|
||||||
|
* built during [10, 20) (half open set) rounds are used in this
|
||||||
|
* prediction.
|
||||||
|
* @param predict_type What kind of prediction to run.
|
||||||
|
* @return predict Result matrix
|
||||||
|
*/
|
||||||
|
public float[][] inplace_predict(float[] data,
|
||||||
|
int nrow,
|
||||||
|
int ncol,
|
||||||
|
float missing,
|
||||||
|
int[] iteration_range,
|
||||||
|
PredictionType predict_type,
|
||||||
|
float[] base_margin) throws XGBoostError {
|
||||||
|
if (iteration_range.length != 2) {
|
||||||
|
throw new XGBoostError(new String("Iteration range is expected to be [begin, end)."));
|
||||||
|
}
|
||||||
|
int ptype = predict_type.getPType();
|
||||||
|
|
||||||
|
int begin = iteration_range[0];
|
||||||
|
int end = iteration_range[1];
|
||||||
|
|
||||||
|
float[][] rawPredicts = new float[1][];
|
||||||
|
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterPredictFromDense(handle, data, nrow, ncol,
|
||||||
|
missing,
|
||||||
|
begin, end, ptype, base_margin, rawPredicts));
|
||||||
|
|
||||||
|
int col = rawPredicts[0].length / nrow;
|
||||||
|
float[][] predicts = new float[nrow][col];
|
||||||
|
int r, c;
|
||||||
|
for (int i = 0; i < rawPredicts[0].length; i++) {
|
||||||
|
r = i / col;
|
||||||
|
c = i % col;
|
||||||
|
predicts[r][c] = rawPredicts[0][i];
|
||||||
|
}
|
||||||
|
return predicts;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Predict leaf indices given the data
|
* Predict leaf indices given the data
|
||||||
*
|
*
|
||||||
|
|||||||
@ -119,6 +119,10 @@ class XGBoostJNI {
|
|||||||
public final static native int XGBoosterPredict(long handle, long dmat, int option_mask,
|
public final static native int XGBoosterPredict(long handle, long dmat, int option_mask,
|
||||||
int ntree_limit, float[][] predicts);
|
int ntree_limit, float[][] predicts);
|
||||||
|
|
||||||
|
public final static native int XGBoosterPredictFromDense(long handle, float[] data,
|
||||||
|
long nrow, long ncol, float missing, int iteration_begin, int iteration_end, int predict_type, float[] margin,
|
||||||
|
float[][] predicts);
|
||||||
|
|
||||||
public final static native int XGBoosterLoadModel(long handle, String fname);
|
public final static native int XGBoosterLoadModel(long handle, String fname);
|
||||||
|
|
||||||
public final static native int XGBoosterSaveModel(long handle, String fname);
|
public final static native int XGBoosterSaveModel(long handle, String fname);
|
||||||
@ -154,10 +158,6 @@ class XGBoostJNI {
|
|||||||
public final static native int XGDMatrixSetInfoFromInterface(
|
public final static native int XGDMatrixSetInfoFromInterface(
|
||||||
long handle, String field, String json);
|
long handle, String field, String json);
|
||||||
|
|
||||||
@Deprecated
|
|
||||||
public final static native int XGDeviceQuantileDMatrixCreateFromCallback(
|
|
||||||
java.util.Iterator<ColumnBatch> iter, float missing, int nthread, int maxBin, long[] out);
|
|
||||||
|
|
||||||
public final static native int XGQuantileDMatrixCreateFromCallback(
|
public final static native int XGQuantileDMatrixCreateFromCallback(
|
||||||
java.util.Iterator<ColumnBatch> iter, java.util.Iterator<ColumnBatch> ref, String config, long[] out);
|
java.util.Iterator<ColumnBatch> iter, java.util.Iterator<ColumnBatch> ref, String config, long[] out);
|
||||||
|
|
||||||
|
|||||||
@ -684,6 +684,85 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict
|
|||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
|
* Method: XGBoosterPredictFromDense
|
||||||
|
* Signature: (J[FJJFIII[F[[F)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredictFromDense(
|
||||||
|
JNIEnv *jenv, jclass jcls, jlong jhandle, jfloatArray jdata, jlong num_rows, jlong num_features,
|
||||||
|
jfloat missing, jint iteration_begin, jint iteration_end, jint predict_type,
|
||||||
|
jfloatArray jmargin, jobjectArray jout) {
|
||||||
|
API_BEGIN();
|
||||||
|
BoosterHandle handle = reinterpret_cast<BoosterHandle>(jhandle);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create array interface.
|
||||||
|
*/
|
||||||
|
namespace linalg = xgboost::linalg;
|
||||||
|
jfloat *data = jenv->GetFloatArrayElements(jdata, nullptr);
|
||||||
|
xgboost::Context ctx;
|
||||||
|
auto t_data = linalg::MakeTensorView(
|
||||||
|
ctx.Device(),
|
||||||
|
xgboost::common::Span{data, static_cast<std::size_t>(num_rows * num_features)}, num_rows,
|
||||||
|
num_features);
|
||||||
|
auto s_array = linalg::ArrayInterfaceStr(t_data);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create configuration object.
|
||||||
|
*/
|
||||||
|
xgboost::Json config{xgboost::Object{}};
|
||||||
|
config["cache_id"] = xgboost::Integer{};
|
||||||
|
config["type"] = xgboost::Integer{static_cast<std::int32_t>(predict_type)};
|
||||||
|
config["iteration_begin"] = xgboost::Integer{static_cast<xgboost::bst_layer_t>(iteration_begin)};
|
||||||
|
config["iteration_end"] = xgboost::Integer{static_cast<xgboost::bst_layer_t>(iteration_end)};
|
||||||
|
config["missing"] = xgboost::Number{static_cast<float>(missing)};
|
||||||
|
config["strict_shape"] = xgboost::Boolean{true};
|
||||||
|
std::string s_config;
|
||||||
|
xgboost::Json::Dump(config, &s_config);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handle base margin
|
||||||
|
*/
|
||||||
|
BoosterHandle proxy{nullptr};
|
||||||
|
|
||||||
|
float *margin{nullptr};
|
||||||
|
if (jmargin) {
|
||||||
|
margin = jenv->GetFloatArrayElements(jmargin, nullptr);
|
||||||
|
JVM_CHECK_CALL(XGProxyDMatrixCreate(&proxy));
|
||||||
|
JVM_CHECK_CALL(
|
||||||
|
XGDMatrixSetFloatInfo(proxy, "base_margin", margin, jenv->GetArrayLength(jmargin)));
|
||||||
|
}
|
||||||
|
|
||||||
|
bst_ulong const *out_shape;
|
||||||
|
bst_ulong out_dim;
|
||||||
|
float const *result;
|
||||||
|
auto ret = XGBoosterPredictFromDense(handle, s_array.c_str(), s_config.c_str(), proxy, &out_shape,
|
||||||
|
&out_dim, &result);
|
||||||
|
|
||||||
|
jenv->ReleaseFloatArrayElements(jdata, data, 0);
|
||||||
|
if (proxy) {
|
||||||
|
XGDMatrixFree(proxy);
|
||||||
|
jenv->ReleaseFloatArrayElements(jmargin, margin, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ret != 0) {
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::size_t n{1};
|
||||||
|
for (std::size_t i = 0; i < out_dim; ++i) {
|
||||||
|
n *= out_shape[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
jfloatArray jarray = jenv->NewFloatArray(n);
|
||||||
|
|
||||||
|
jenv->SetFloatArrayRegion(jarray, 0, n, result);
|
||||||
|
jenv->SetObjectArrayElement(jout, 0, jarray);
|
||||||
|
|
||||||
|
API_END();
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
* Method: XGBoosterLoadModel
|
* Method: XGBoosterLoadModel
|
||||||
|
|||||||
@ -207,6 +207,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterEvalOneIt
|
|||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict
|
||||||
(JNIEnv *, jclass, jlong, jlong, jint, jint, jobjectArray);
|
(JNIEnv *, jclass, jlong, jlong, jint, jint, jobjectArray);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
|
* Method: XGBoosterPredictFromDense
|
||||||
|
* Signature: (J[FJJFIII[F[[F)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredictFromDense
|
||||||
|
(JNIEnv *, jclass, jlong, jfloatArray, jlong, jlong, jfloat, jint, jint, jint, jfloatArray, jobjectArray);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
* Method: XGBoosterLoadModel
|
* Method: XGBoosterLoadModel
|
||||||
@ -359,14 +367,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorAllred
|
|||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetInfoFromInterface
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetInfoFromInterface
|
||||||
(JNIEnv *, jclass, jlong, jstring, jstring);
|
(JNIEnv *, jclass, jlong, jstring, jstring);
|
||||||
|
|
||||||
/*
|
|
||||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
|
||||||
* Method: XGDeviceQuantileDMatrixCreateFromCallback
|
|
||||||
* Signature: (Ljava/util/Iterator;FII[J)I
|
|
||||||
*/
|
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDMatrixCreateFromCallback
|
|
||||||
(JNIEnv *, jclass, jobject, jfloat, jint, jint, jlongArray);
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
* Method: XGQuantileDMatrixCreateFromCallback
|
* Method: XGQuantileDMatrixCreateFromCallback
|
||||||
|
|||||||
@ -15,16 +15,24 @@
|
|||||||
*/
|
*/
|
||||||
package ml.dmlc.xgboost4j.java;
|
package ml.dmlc.xgboost4j.java;
|
||||||
|
|
||||||
import java.io.*;
|
|
||||||
import java.util.*;
|
|
||||||
|
|
||||||
import junit.framework.TestCase;
|
import junit.framework.TestCase;
|
||||||
|
import org.junit.Assert;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import java.io.ByteArrayInputStream;
|
||||||
|
import java.io.ByteArrayOutputStream;
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.*;
|
||||||
|
import java.util.concurrent.*;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertArrayEquals;
|
||||||
|
import static org.junit.Assert.fail;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* test cases for Booster
|
* test cases for Booster Inplace Predict
|
||||||
*
|
*
|
||||||
* @author hzx
|
* @author hzx and Sovrn
|
||||||
*/
|
*/
|
||||||
public class BoosterImplTest {
|
public class BoosterImplTest {
|
||||||
private String train_uri = "../../demo/data/agaricus.txt.train?indexing_mode=1&format=libsvm";
|
private String train_uri = "../../demo/data/agaricus.txt.train?indexing_mode=1&format=libsvm";
|
||||||
@ -99,6 +107,179 @@ public class BoosterImplTest {
|
|||||||
TestCase.assertTrue(eval.eval(predicts, testMat) < 0.1f);
|
TestCase.assertTrue(eval.eval(predicts, testMat) < 0.1f);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void inplacePredictTest() throws XGBoostError {
|
||||||
|
/* Data Generation */
|
||||||
|
// Generate a training set.
|
||||||
|
int trainRows = 1000;
|
||||||
|
int features = 10;
|
||||||
|
int trainSize = trainRows * features;
|
||||||
|
float[] trainX = generateRandomDataSet(trainSize);
|
||||||
|
float[] trainY = generateRandomDataSet(trainRows);
|
||||||
|
|
||||||
|
DMatrix trainingMatrix = new DMatrix(trainX, trainRows, features, Float.NaN);
|
||||||
|
trainingMatrix.setLabel(trainY);
|
||||||
|
|
||||||
|
// Generate a testing set
|
||||||
|
int testRows = 10;
|
||||||
|
int testSize = testRows * features;
|
||||||
|
float[] testX = generateRandomDataSet(testSize);
|
||||||
|
float[] testY = generateRandomDataSet(testRows);
|
||||||
|
|
||||||
|
DMatrix testingMatrix = new DMatrix(testX, testRows, features, Float.NaN);
|
||||||
|
testingMatrix.setLabel(testY);
|
||||||
|
|
||||||
|
/* Training */
|
||||||
|
|
||||||
|
// Set parameters
|
||||||
|
Map<String, Object> params = new HashMap<>();
|
||||||
|
params.put("eta", 1.0);
|
||||||
|
params.put("max_depth",2);
|
||||||
|
params.put("silent", 1);
|
||||||
|
params.put("tree_method", "hist");
|
||||||
|
|
||||||
|
Map<String, DMatrix> watches = new HashMap<>();
|
||||||
|
watches.put("train", trainingMatrix);
|
||||||
|
watches.put("test", testingMatrix);
|
||||||
|
|
||||||
|
Booster booster = XGBoost.train(trainingMatrix, params, 10, watches, null, null);
|
||||||
|
|
||||||
|
/* Prediction */
|
||||||
|
|
||||||
|
// Standard prediction
|
||||||
|
float[][] predictions = booster.predict(testingMatrix);
|
||||||
|
|
||||||
|
// Inplace-prediction
|
||||||
|
float[][] inplacePredictions = booster.inplace_predict(testX, testRows, features, Float.NaN);
|
||||||
|
|
||||||
|
// Confirm that the two prediction results are identical
|
||||||
|
assertArrayEquals(predictions, inplacePredictions);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void inplacePredictMultiPredictTest() throws InterruptedException {
|
||||||
|
// Multithreaded, multiple prediction
|
||||||
|
int trainRows = 1000;
|
||||||
|
int features = 10;
|
||||||
|
int trainSize = trainRows * features;
|
||||||
|
|
||||||
|
int testRows = 10;
|
||||||
|
int testSize = testRows * features;
|
||||||
|
|
||||||
|
//Simulate multiple predictions on multiple random data sets simultaneously.
|
||||||
|
ExecutorService executorService = Executors.newFixedThreadPool(5);
|
||||||
|
int predictsToPerform = 100;
|
||||||
|
for(int i = 0; i < predictsToPerform; i++) {
|
||||||
|
executorService.submit(() -> {
|
||||||
|
try {
|
||||||
|
float[] trainX = generateRandomDataSet(trainSize);
|
||||||
|
float[] trainY = generateRandomDataSet(trainRows);
|
||||||
|
DMatrix trainingMatrix = new DMatrix(trainX, trainRows, features, Float.NaN);
|
||||||
|
trainingMatrix.setLabel(trainY);
|
||||||
|
|
||||||
|
float[] testX = generateRandomDataSet(testSize);
|
||||||
|
float[] testY = generateRandomDataSet(testRows);
|
||||||
|
DMatrix testingMatrix = new DMatrix(testX, testRows, features, Float.NaN);
|
||||||
|
testingMatrix.setLabel(testY);
|
||||||
|
|
||||||
|
Map<String, Object> params = new HashMap<>();
|
||||||
|
params.put("eta", 1.0);
|
||||||
|
params.put("max_depth", 2);
|
||||||
|
params.put("silent", 1);
|
||||||
|
params.put("tree_method", "hist");
|
||||||
|
|
||||||
|
Map<String, DMatrix> watches = new HashMap<>();
|
||||||
|
watches.put("train", trainingMatrix);
|
||||||
|
watches.put("test", testingMatrix);
|
||||||
|
|
||||||
|
Booster booster = XGBoost.train(trainingMatrix, params, 10, watches, null, null);
|
||||||
|
|
||||||
|
float[][] predictions = booster.predict(testingMatrix);
|
||||||
|
float[][] inplacePredictions = booster.inplace_predict(testX, testRows, features, Float.NaN);
|
||||||
|
|
||||||
|
assertArrayEquals(predictions, inplacePredictions);
|
||||||
|
} catch (XGBoostError xgBoostError) {
|
||||||
|
fail(xgBoostError.getMessage());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
executorService.shutdown();
|
||||||
|
if(!executorService.awaitTermination(1, TimeUnit.MINUTES))
|
||||||
|
executorService.shutdownNow();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void inplacePredictWithMarginTest() throws XGBoostError {
|
||||||
|
//Generate a training set
|
||||||
|
int trainRows = 1000;
|
||||||
|
int features = 10;
|
||||||
|
int trainSize = trainRows * features;
|
||||||
|
float[] trainX = generateRandomDataSet(trainSize);
|
||||||
|
float[] trainY = generateRandomDataSet(trainRows);
|
||||||
|
|
||||||
|
DMatrix trainingMatrix = new DMatrix(trainX, trainRows, features, Float.NaN);
|
||||||
|
trainingMatrix.setLabel(trainY);
|
||||||
|
|
||||||
|
// Generate a testing set
|
||||||
|
int testRows = 10;
|
||||||
|
int testSize = testRows * features;
|
||||||
|
float[] testX = generateRandomDataSet(testSize);
|
||||||
|
float[] testY = generateRandomDataSet(testRows);
|
||||||
|
|
||||||
|
DMatrix testingMatrix = new DMatrix(testX, testRows, features, Float.NaN);
|
||||||
|
testingMatrix.setLabel(testY);
|
||||||
|
|
||||||
|
// Set booster parameters
|
||||||
|
Map<String, Object> params = new HashMap<>();
|
||||||
|
params.put("eta", 1.0);
|
||||||
|
params.put("max_depth",2);
|
||||||
|
params.put("tree_method", "hist");
|
||||||
|
params.put("base_score", 0.0);
|
||||||
|
|
||||||
|
Map<String, DMatrix> watches = new HashMap<>();
|
||||||
|
watches.put("train", trainingMatrix);
|
||||||
|
watches.put("test", testingMatrix);
|
||||||
|
|
||||||
|
// Train booster on training matrix.
|
||||||
|
Booster booster = XGBoost.train(trainingMatrix, params, 10, watches, null, null);
|
||||||
|
|
||||||
|
// Create a margin
|
||||||
|
float[] margin = new float[testRows];
|
||||||
|
Arrays.fill(margin, 0.5f);
|
||||||
|
|
||||||
|
// Define an iteration range to use all training iterations, this should match
|
||||||
|
// the without margin call
|
||||||
|
// which defines an iteration range of [0,0)
|
||||||
|
int[] iterationRange = new int[] { 0, 0 };
|
||||||
|
|
||||||
|
float[][] inplacePredictionsWithMargin = booster.inplace_predict(testX,
|
||||||
|
testRows,
|
||||||
|
features,
|
||||||
|
Float.NaN,
|
||||||
|
iterationRange,
|
||||||
|
Booster.PredictionType.kValue,
|
||||||
|
margin);
|
||||||
|
float[][] inplacePredictionsWithoutMargin = booster.inplace_predict(testX, testRows, features, Float.NaN);
|
||||||
|
|
||||||
|
for (int i = 0; i < inplacePredictionsWithoutMargin.length; i++) {
|
||||||
|
for (int j = 0; j < inplacePredictionsWithoutMargin[i].length; j++) {
|
||||||
|
inplacePredictionsWithoutMargin[i][j] += margin[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int i = 0; i < inplacePredictionsWithoutMargin.length; i++) {
|
||||||
|
assertArrayEquals(inplacePredictionsWithMargin[i], inplacePredictionsWithoutMargin[i], 1e-6f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private float[] generateRandomDataSet(int size) {
|
||||||
|
float[] newSet = new float[size];
|
||||||
|
Random random = new Random();
|
||||||
|
for(int i = 0; i < size; i++) {
|
||||||
|
newSet[i] = random.nextFloat();
|
||||||
|
}
|
||||||
|
return newSet;
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void saveLoadModelWithPath() throws XGBoostError, IOException {
|
public void saveLoadModelWithPath() throws XGBoostError, IOException {
|
||||||
DMatrix trainMat = new DMatrix(this.train_uri);
|
DMatrix trainMat = new DMatrix(this.train_uri);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user