Allow JVM-Package to access inplace predict method (#9167)

---------

Co-authored-by: Stephan T. Lavavej <stl@nuwen.net>
Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
Co-authored-by: Joe <25804777+ByteSizedJoe@users.noreply.github.com>
This commit is contained in:
Jon Yoquinto
2023-09-11 17:29:51 -06:00
committed by GitHub
parent 9027686cac
commit d05ea589fb
5 changed files with 384 additions and 18 deletions

View File

@@ -15,16 +15,24 @@
*/
package ml.dmlc.xgboost4j.java;
import java.io.*;
import java.util.*;
import junit.framework.TestCase;
import org.junit.Assert;
import org.junit.Test;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.*;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.fail;
/**
* test cases for Booster
*
* @author hzx
* test cases for Booster Inplace Predict
*
* @author hzx and Sovrn
*/
public class BoosterImplTest {
private String train_uri = "../../demo/data/agaricus.txt.train?indexing_mode=1&format=libsvm";
@@ -99,6 +107,179 @@ public class BoosterImplTest {
TestCase.assertTrue(eval.eval(predicts, testMat) < 0.1f);
}
@Test
public void inplacePredictTest() throws XGBoostError {
/* Data Generation */
// Generate a training set.
int trainRows = 1000;
int features = 10;
int trainSize = trainRows * features;
float[] trainX = generateRandomDataSet(trainSize);
float[] trainY = generateRandomDataSet(trainRows);
DMatrix trainingMatrix = new DMatrix(trainX, trainRows, features, Float.NaN);
trainingMatrix.setLabel(trainY);
// Generate a testing set
int testRows = 10;
int testSize = testRows * features;
float[] testX = generateRandomDataSet(testSize);
float[] testY = generateRandomDataSet(testRows);
DMatrix testingMatrix = new DMatrix(testX, testRows, features, Float.NaN);
testingMatrix.setLabel(testY);
/* Training */
// Set parameters
Map<String, Object> params = new HashMap<>();
params.put("eta", 1.0);
params.put("max_depth",2);
params.put("silent", 1);
params.put("tree_method", "hist");
Map<String, DMatrix> watches = new HashMap<>();
watches.put("train", trainingMatrix);
watches.put("test", testingMatrix);
Booster booster = XGBoost.train(trainingMatrix, params, 10, watches, null, null);
/* Prediction */
// Standard prediction
float[][] predictions = booster.predict(testingMatrix);
// Inplace-prediction
float[][] inplacePredictions = booster.inplace_predict(testX, testRows, features, Float.NaN);
// Confirm that the two prediction results are identical
assertArrayEquals(predictions, inplacePredictions);
}
@Test
public void inplacePredictMultiPredictTest() throws InterruptedException {
// Multithreaded, multiple prediction
int trainRows = 1000;
int features = 10;
int trainSize = trainRows * features;
int testRows = 10;
int testSize = testRows * features;
//Simulate multiple predictions on multiple random data sets simultaneously.
ExecutorService executorService = Executors.newFixedThreadPool(5);
int predictsToPerform = 100;
for(int i = 0; i < predictsToPerform; i++) {
executorService.submit(() -> {
try {
float[] trainX = generateRandomDataSet(trainSize);
float[] trainY = generateRandomDataSet(trainRows);
DMatrix trainingMatrix = new DMatrix(trainX, trainRows, features, Float.NaN);
trainingMatrix.setLabel(trainY);
float[] testX = generateRandomDataSet(testSize);
float[] testY = generateRandomDataSet(testRows);
DMatrix testingMatrix = new DMatrix(testX, testRows, features, Float.NaN);
testingMatrix.setLabel(testY);
Map<String, Object> params = new HashMap<>();
params.put("eta", 1.0);
params.put("max_depth", 2);
params.put("silent", 1);
params.put("tree_method", "hist");
Map<String, DMatrix> watches = new HashMap<>();
watches.put("train", trainingMatrix);
watches.put("test", testingMatrix);
Booster booster = XGBoost.train(trainingMatrix, params, 10, watches, null, null);
float[][] predictions = booster.predict(testingMatrix);
float[][] inplacePredictions = booster.inplace_predict(testX, testRows, features, Float.NaN);
assertArrayEquals(predictions, inplacePredictions);
} catch (XGBoostError xgBoostError) {
fail(xgBoostError.getMessage());
}
});
}
executorService.shutdown();
if(!executorService.awaitTermination(1, TimeUnit.MINUTES))
executorService.shutdownNow();
}
@Test
public void inplacePredictWithMarginTest() throws XGBoostError {
//Generate a training set
int trainRows = 1000;
int features = 10;
int trainSize = trainRows * features;
float[] trainX = generateRandomDataSet(trainSize);
float[] trainY = generateRandomDataSet(trainRows);
DMatrix trainingMatrix = new DMatrix(trainX, trainRows, features, Float.NaN);
trainingMatrix.setLabel(trainY);
// Generate a testing set
int testRows = 10;
int testSize = testRows * features;
float[] testX = generateRandomDataSet(testSize);
float[] testY = generateRandomDataSet(testRows);
DMatrix testingMatrix = new DMatrix(testX, testRows, features, Float.NaN);
testingMatrix.setLabel(testY);
// Set booster parameters
Map<String, Object> params = new HashMap<>();
params.put("eta", 1.0);
params.put("max_depth",2);
params.put("tree_method", "hist");
params.put("base_score", 0.0);
Map<String, DMatrix> watches = new HashMap<>();
watches.put("train", trainingMatrix);
watches.put("test", testingMatrix);
// Train booster on training matrix.
Booster booster = XGBoost.train(trainingMatrix, params, 10, watches, null, null);
// Create a margin
float[] margin = new float[testRows];
Arrays.fill(margin, 0.5f);
// Define an iteration range to use all training iterations, this should match
// the without margin call
// which defines an iteration range of [0,0)
int[] iterationRange = new int[] { 0, 0 };
float[][] inplacePredictionsWithMargin = booster.inplace_predict(testX,
testRows,
features,
Float.NaN,
iterationRange,
Booster.PredictionType.kValue,
margin);
float[][] inplacePredictionsWithoutMargin = booster.inplace_predict(testX, testRows, features, Float.NaN);
for (int i = 0; i < inplacePredictionsWithoutMargin.length; i++) {
for (int j = 0; j < inplacePredictionsWithoutMargin[i].length; j++) {
inplacePredictionsWithoutMargin[i][j] += margin[j];
}
}
for (int i = 0; i < inplacePredictionsWithoutMargin.length; i++) {
assertArrayEquals(inplacePredictionsWithMargin[i], inplacePredictionsWithoutMargin[i], 1e-6f);
}
}
private float[] generateRandomDataSet(int size) {
float[] newSet = new float[size];
Random random = new Random();
for(int i = 0; i < size; i++) {
newSet[i] = random.nextFloat();
}
return newSet;
}
@Test
public void saveLoadModelWithPath() throws XGBoostError, IOException {
DMatrix trainMat = new DMatrix(this.train_uri);