Allow JVM-Package to access inplace predict method (#9167)
--------- Co-authored-by: Stephan T. Lavavej <stl@nuwen.net> Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com> Co-authored-by: Joe <25804777+ByteSizedJoe@users.noreply.github.com>
This commit is contained in:
@@ -15,16 +15,24 @@
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.*;
|
||||
import java.util.*;
|
||||
|
||||
import junit.framework.TestCase;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.*;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.fail;
|
||||
|
||||
/**
|
||||
* test cases for Booster
|
||||
*
|
||||
* @author hzx
|
||||
* test cases for Booster Inplace Predict
|
||||
*
|
||||
* @author hzx and Sovrn
|
||||
*/
|
||||
public class BoosterImplTest {
|
||||
private String train_uri = "../../demo/data/agaricus.txt.train?indexing_mode=1&format=libsvm";
|
||||
@@ -99,6 +107,179 @@ public class BoosterImplTest {
|
||||
TestCase.assertTrue(eval.eval(predicts, testMat) < 0.1f);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void inplacePredictTest() throws XGBoostError {
|
||||
/* Data Generation */
|
||||
// Generate a training set.
|
||||
int trainRows = 1000;
|
||||
int features = 10;
|
||||
int trainSize = trainRows * features;
|
||||
float[] trainX = generateRandomDataSet(trainSize);
|
||||
float[] trainY = generateRandomDataSet(trainRows);
|
||||
|
||||
DMatrix trainingMatrix = new DMatrix(trainX, trainRows, features, Float.NaN);
|
||||
trainingMatrix.setLabel(trainY);
|
||||
|
||||
// Generate a testing set
|
||||
int testRows = 10;
|
||||
int testSize = testRows * features;
|
||||
float[] testX = generateRandomDataSet(testSize);
|
||||
float[] testY = generateRandomDataSet(testRows);
|
||||
|
||||
DMatrix testingMatrix = new DMatrix(testX, testRows, features, Float.NaN);
|
||||
testingMatrix.setLabel(testY);
|
||||
|
||||
/* Training */
|
||||
|
||||
// Set parameters
|
||||
Map<String, Object> params = new HashMap<>();
|
||||
params.put("eta", 1.0);
|
||||
params.put("max_depth",2);
|
||||
params.put("silent", 1);
|
||||
params.put("tree_method", "hist");
|
||||
|
||||
Map<String, DMatrix> watches = new HashMap<>();
|
||||
watches.put("train", trainingMatrix);
|
||||
watches.put("test", testingMatrix);
|
||||
|
||||
Booster booster = XGBoost.train(trainingMatrix, params, 10, watches, null, null);
|
||||
|
||||
/* Prediction */
|
||||
|
||||
// Standard prediction
|
||||
float[][] predictions = booster.predict(testingMatrix);
|
||||
|
||||
// Inplace-prediction
|
||||
float[][] inplacePredictions = booster.inplace_predict(testX, testRows, features, Float.NaN);
|
||||
|
||||
// Confirm that the two prediction results are identical
|
||||
assertArrayEquals(predictions, inplacePredictions);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void inplacePredictMultiPredictTest() throws InterruptedException {
|
||||
// Multithreaded, multiple prediction
|
||||
int trainRows = 1000;
|
||||
int features = 10;
|
||||
int trainSize = trainRows * features;
|
||||
|
||||
int testRows = 10;
|
||||
int testSize = testRows * features;
|
||||
|
||||
//Simulate multiple predictions on multiple random data sets simultaneously.
|
||||
ExecutorService executorService = Executors.newFixedThreadPool(5);
|
||||
int predictsToPerform = 100;
|
||||
for(int i = 0; i < predictsToPerform; i++) {
|
||||
executorService.submit(() -> {
|
||||
try {
|
||||
float[] trainX = generateRandomDataSet(trainSize);
|
||||
float[] trainY = generateRandomDataSet(trainRows);
|
||||
DMatrix trainingMatrix = new DMatrix(trainX, trainRows, features, Float.NaN);
|
||||
trainingMatrix.setLabel(trainY);
|
||||
|
||||
float[] testX = generateRandomDataSet(testSize);
|
||||
float[] testY = generateRandomDataSet(testRows);
|
||||
DMatrix testingMatrix = new DMatrix(testX, testRows, features, Float.NaN);
|
||||
testingMatrix.setLabel(testY);
|
||||
|
||||
Map<String, Object> params = new HashMap<>();
|
||||
params.put("eta", 1.0);
|
||||
params.put("max_depth", 2);
|
||||
params.put("silent", 1);
|
||||
params.put("tree_method", "hist");
|
||||
|
||||
Map<String, DMatrix> watches = new HashMap<>();
|
||||
watches.put("train", trainingMatrix);
|
||||
watches.put("test", testingMatrix);
|
||||
|
||||
Booster booster = XGBoost.train(trainingMatrix, params, 10, watches, null, null);
|
||||
|
||||
float[][] predictions = booster.predict(testingMatrix);
|
||||
float[][] inplacePredictions = booster.inplace_predict(testX, testRows, features, Float.NaN);
|
||||
|
||||
assertArrayEquals(predictions, inplacePredictions);
|
||||
} catch (XGBoostError xgBoostError) {
|
||||
fail(xgBoostError.getMessage());
|
||||
}
|
||||
});
|
||||
}
|
||||
executorService.shutdown();
|
||||
if(!executorService.awaitTermination(1, TimeUnit.MINUTES))
|
||||
executorService.shutdownNow();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void inplacePredictWithMarginTest() throws XGBoostError {
|
||||
//Generate a training set
|
||||
int trainRows = 1000;
|
||||
int features = 10;
|
||||
int trainSize = trainRows * features;
|
||||
float[] trainX = generateRandomDataSet(trainSize);
|
||||
float[] trainY = generateRandomDataSet(trainRows);
|
||||
|
||||
DMatrix trainingMatrix = new DMatrix(trainX, trainRows, features, Float.NaN);
|
||||
trainingMatrix.setLabel(trainY);
|
||||
|
||||
// Generate a testing set
|
||||
int testRows = 10;
|
||||
int testSize = testRows * features;
|
||||
float[] testX = generateRandomDataSet(testSize);
|
||||
float[] testY = generateRandomDataSet(testRows);
|
||||
|
||||
DMatrix testingMatrix = new DMatrix(testX, testRows, features, Float.NaN);
|
||||
testingMatrix.setLabel(testY);
|
||||
|
||||
// Set booster parameters
|
||||
Map<String, Object> params = new HashMap<>();
|
||||
params.put("eta", 1.0);
|
||||
params.put("max_depth",2);
|
||||
params.put("tree_method", "hist");
|
||||
params.put("base_score", 0.0);
|
||||
|
||||
Map<String, DMatrix> watches = new HashMap<>();
|
||||
watches.put("train", trainingMatrix);
|
||||
watches.put("test", testingMatrix);
|
||||
|
||||
// Train booster on training matrix.
|
||||
Booster booster = XGBoost.train(trainingMatrix, params, 10, watches, null, null);
|
||||
|
||||
// Create a margin
|
||||
float[] margin = new float[testRows];
|
||||
Arrays.fill(margin, 0.5f);
|
||||
|
||||
// Define an iteration range to use all training iterations, this should match
|
||||
// the without margin call
|
||||
// which defines an iteration range of [0,0)
|
||||
int[] iterationRange = new int[] { 0, 0 };
|
||||
|
||||
float[][] inplacePredictionsWithMargin = booster.inplace_predict(testX,
|
||||
testRows,
|
||||
features,
|
||||
Float.NaN,
|
||||
iterationRange,
|
||||
Booster.PredictionType.kValue,
|
||||
margin);
|
||||
float[][] inplacePredictionsWithoutMargin = booster.inplace_predict(testX, testRows, features, Float.NaN);
|
||||
|
||||
for (int i = 0; i < inplacePredictionsWithoutMargin.length; i++) {
|
||||
for (int j = 0; j < inplacePredictionsWithoutMargin[i].length; j++) {
|
||||
inplacePredictionsWithoutMargin[i][j] += margin[j];
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < inplacePredictionsWithoutMargin.length; i++) {
|
||||
assertArrayEquals(inplacePredictionsWithMargin[i], inplacePredictionsWithoutMargin[i], 1e-6f);
|
||||
}
|
||||
}
|
||||
|
||||
private float[] generateRandomDataSet(int size) {
|
||||
float[] newSet = new float[size];
|
||||
Random random = new Random();
|
||||
for(int i = 0; i < size; i++) {
|
||||
newSet[i] = random.nextFloat();
|
||||
}
|
||||
return newSet;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void saveLoadModelWithPath() throws XGBoostError, IOException {
|
||||
DMatrix trainMat = new DMatrix(this.train_uri);
|
||||
|
||||
Reference in New Issue
Block a user