[jvm-packages] Add BigDenseMatrix (#4383)

* Add BigDenseMatrix

* ability to create a DMatrix from arrays with more than Integer.MAX_VALUE elements
* uses sun.misc.Unsafe

* make DMatrix test work from a jar as well
This commit is contained in:
Honza Sterba
2019-09-19 05:46:14 +02:00
committed by Philip Hyunsu Cho
parent 57106a3459
commit 22209b7b95
8 changed files with 300 additions and 1 deletions

View File

@@ -15,13 +15,20 @@
*/
package ml.dmlc.xgboost4j.java;
import java.io.*;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import junit.framework.TestCase;
import ml.dmlc.xgboost4j.java.util.BigDenseMatrix;
import ml.dmlc.xgboost4j.LabeledPoint;
import org.junit.Test;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
/**
* test cases for DMatrix
*
@@ -53,7 +60,8 @@ public class DMatrixTest {
@Test
public void testCreateFromFile() throws XGBoostError {
//create DMatrix from file
DMatrix dmat = new DMatrix("../../demo/data/agaricus.txt.test");
String filePath = writeResourceIntoTempFile("/agaricus.txt.test");
DMatrix dmat = new DMatrix(filePath);
//get label
float[] labels = dmat.getLabel();
//check length
@@ -224,6 +232,122 @@ public class DMatrixTest {
TestCase.assertTrue(dmat0.getLabel().length == 10);
}
/**
 * Creates a DMatrix from a 10x5 BigDenseMatrix filled with random floats, attaches a random
 * label vector, and checks the reported row count and label length.
 */
@Test
public void testCreateFromDenseMatrixRef() throws XGBoostError {
  final int nrow = 10;
  final int ncol = 5;
  BigDenseMatrix dense = null;
  DMatrix matrix = null;
  try {
    dense = new BigDenseMatrix(nrow, ncol);

    // populate every cell through the single-index setter
    Random rng = new Random();
    final int cellCount = nrow * ncol;
    int cell = 0;
    while (cell < cellCount) {
      dense.set(cell, rng.nextFloat());
      cell++;
    }

    // one random label per row
    float[] labels = new float[nrow];
    for (int row = 0; row < labels.length; row++) {
      labels[row] = rng.nextFloat();
    }

    matrix = new DMatrix(dense);
    matrix.setLabel(labels);

    // verify dimensions
    TestCase.assertTrue(matrix.rowNum() == nrow);
    TestCase.assertTrue(matrix.getLabel().length == nrow);
  } finally {
    // NOTE(review): the dense matrix is disposed only when no DMatrix was created;
    // presumably DMatrix.dispose() is expected to cover the backing data - confirm.
    if (matrix == null) {
      if (dense != null) {
        dense.dispose();
      }
    } else {
      matrix.dispose();
    }
  }
}
/**
 * Trains a booster on a tiny 3x2 dataset held in a BigDenseMatrix and checks that the model
 * reproduces the training labels (within 1e-2) when predicting each row individually.
 *
 * Rabit is initialised before training and shut down in the finally block.
 */
@Test
public void testTrainWithDenseMatrixRef() throws XGBoostError {
  Map<String, String> rabitEnv = new HashMap<>();
  rabitEnv.put("DMLC_TASK_ID", "0");
  Rabit.init(rabitEnv);
  DMatrix trainMat = null;
  BigDenseMatrix data0 = null;
  try {
    // trivial dataset with 3 rows and 2 columns
    // (4,5) -> 1
    // (3,1) -> 2
    // (2,3) -> 3
    float[][] data = new float[][]{
      new float[]{4f, 5f},
      new float[]{3f, 1f},
      new float[]{2f, 3f}
    };
    data0 = new BigDenseMatrix(3, 2);
    for (int i = 0; i < data0.nrow; i++) {
      for (int j = 0; j < data0.ncol; j++) {
        data0.set(i, j, data[i][j]);
      }
    }
    trainMat = new DMatrix(data0);
    trainMat.setLabel(new float[]{1f, 2f, 3f});

    HashMap<String, Object> params = new HashMap<>();
    params.put("eta", 1);
    params.put("max_depth", 5);
    params.put("silent", 1);
    params.put("objective", "reg:linear");
    params.put("seed", 123);

    HashMap<String, DMatrix> watches = new HashMap<>();
    watches.put("train", trainMat);

    Booster booster = XGBoost.train(trainMat, params, 10, watches, null, null);

    // check overfitting
    // (4,5) -> 1
    // (3,1) -> 2
    // (2,3) -> 3
    for (int i = 0; i < 3; i++) {
      // Each single-row prediction matrix is released as soon as it has been used,
      // mirroring the explicit dispose() applied to the training matrix below.
      DMatrix rowMat = new DMatrix(data[i], 1, 2);
      try {
        float[][] preds = booster.predict(rowMat);
        assertEquals(1, preds.length);
        assertArrayEquals(new float[]{(float) (i + 1)}, preds[0], 1e-2f);
      } finally {
        rowMat.dispose();
      }
    }
  } finally {
    // NOTE(review): data0 is disposed only when trainMat was never created;
    // presumably DMatrix.dispose() is expected to cover the backing data - confirm.
    if (trainMat != null) {
      trainMat.dispose();
    } else if (data0 != null) {
      data0.dispose();
    }
    Rabit.shutdown();
  }
}
/**
 * Copies a classpath resource into a freshly created temporary file so that it can be
 * opened through a plain file-system path (this also works when the resource lives in a jar).
 *
 * The temporary file is registered for deletion when the JVM exits.
 *
 * @param resource resource name, resolved via {@code getClass().getResourceAsStream}
 * @return absolute path of the temporary file holding a copy of the resource
 * @throws IllegalArgumentException if the resource cannot be found
 * @throws RuntimeException if the temporary file cannot be created or written
 */
private String writeResourceIntoTempFile(String resource) {
  InputStream input = getClass().getResourceAsStream(resource);
  if (input == null) {
    throw new IllegalArgumentException("Resource " + resource + " does not exist.");
  }
  // try-with-resources closes the resource stream on every path, including failures
  try (InputStream in = input) {
    File tmp = File.createTempFile("junit", ".test");
    // do not leave copies behind in the temp directory after the test run
    tmp.deleteOnExit();
    try (FileOutputStream output = new FileOutputStream(tmp)) {
      byte[] buff = new byte[1024];
      int n;
      // read() signals end of stream with -1
      while ((n = in.read(buff)) != -1) {
        output.write(buff, 0, n);
      }
    }
    return tmp.getAbsolutePath();
  } catch (IOException e) {
    throw new RuntimeException("Unable to write to temp file.", e);
  }
}
@Test
public void testSetAndGetGroup() throws XGBoostError {
//create DMatrix from 10*5 dense matrix