[jvm-packages] Add BigDenseMatrix (#4383)
* Add BigDenseMatrix
* Add the ability to create a DMatrix from arrays larger than Integer.MAX_VALUE elements
* Uses sun.misc.Unsafe
* Make the DMatrix test work from a jar as well
This commit is contained in:
committed by
Philip Hyunsu Cho
parent
57106a3459
commit
22209b7b95
@@ -15,13 +15,20 @@
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.*;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
|
||||
import junit.framework.TestCase;
|
||||
import ml.dmlc.xgboost4j.java.util.BigDenseMatrix;
|
||||
import ml.dmlc.xgboost4j.LabeledPoint;
|
||||
import org.junit.Test;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
/**
|
||||
* test cases for DMatrix
|
||||
*
|
||||
@@ -53,7 +60,8 @@ public class DMatrixTest {
|
||||
@Test
|
||||
public void testCreateFromFile() throws XGBoostError {
|
||||
//create DMatrix from file
|
||||
DMatrix dmat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
String filePath = writeResourceIntoTempFile("/agaricus.txt.test");
|
||||
DMatrix dmat = new DMatrix(filePath);
|
||||
//get label
|
||||
float[] labels = dmat.getLabel();
|
||||
//check length
|
||||
@@ -224,6 +232,122 @@ public class DMatrixTest {
|
||||
TestCase.assertTrue(dmat0.getLabel().length == 10);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCreateFromDenseMatrixRef() throws XGBoostError {
|
||||
//create DMatrix from 10*5 dense matrix
|
||||
final int nrow = 10;
|
||||
final int ncol = 5;
|
||||
|
||||
DMatrix dmat0 = null;
|
||||
BigDenseMatrix data0 = null;
|
||||
try {
|
||||
data0 = new BigDenseMatrix(nrow, ncol);
|
||||
//put random nums
|
||||
Random random = new Random();
|
||||
for (int i = 0; i < nrow * ncol; i++) {
|
||||
data0.set(i, random.nextFloat());
|
||||
}
|
||||
|
||||
//create label
|
||||
float[] label0 = new float[nrow];
|
||||
for (int i = 0; i < nrow; i++) {
|
||||
label0[i] = random.nextFloat();
|
||||
}
|
||||
|
||||
dmat0 = new DMatrix(data0);
|
||||
dmat0.setLabel(label0);
|
||||
|
||||
//check
|
||||
TestCase.assertTrue(dmat0.rowNum() == 10);
|
||||
TestCase.assertTrue(dmat0.getLabel().length == 10);
|
||||
} finally {
|
||||
if (dmat0 != null) {
|
||||
dmat0.dispose();
|
||||
} else if (data0 != null){
|
||||
data0.dispose();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTrainWithDenseMatrixRef() throws XGBoostError {
|
||||
Map<String, String> rabitEnv = new HashMap<>();
|
||||
rabitEnv.put("DMLC_TASK_ID", "0");
|
||||
Rabit.init(rabitEnv);
|
||||
DMatrix trainMat = null;
|
||||
BigDenseMatrix data0 = null;
|
||||
try {
|
||||
// trivial dataset with 3 rows and 2 columns
|
||||
// (4,5) -> 1
|
||||
// (3,1) -> 2
|
||||
// (2,3) -> 3
|
||||
float[][] data = new float[][]{
|
||||
new float[]{4f, 5f},
|
||||
new float[]{3f, 1f},
|
||||
new float[]{2f, 3f}
|
||||
};
|
||||
data0 = new BigDenseMatrix(3, 2);
|
||||
for (int i = 0; i < data0.nrow; i++)
|
||||
for (int j = 0; j < data0.ncol; j++)
|
||||
data0.set(i, j, data[i][j]);
|
||||
|
||||
trainMat = new DMatrix(data0);
|
||||
trainMat.setLabel(new float[]{1f, 2f, 3f});
|
||||
|
||||
HashMap<String, Object> params = new HashMap<>();
|
||||
params.put("eta", 1);
|
||||
params.put("max_depth", 5);
|
||||
params.put("silent", 1);
|
||||
params.put("objective", "reg:linear");
|
||||
params.put("seed", 123);
|
||||
|
||||
HashMap<String, DMatrix> watches = new HashMap<>();
|
||||
watches.put("train", trainMat);
|
||||
|
||||
Booster booster = XGBoost.train(trainMat, params, 10, watches, null, null);
|
||||
|
||||
// check overfitting
|
||||
// (4,5) -> 1
|
||||
// (3,1) -> 2
|
||||
// (2,3) -> 3
|
||||
for (int i = 0; i < 3; i++) {
|
||||
float[][] preds = booster.predict(new DMatrix(data[i], 1, 2));
|
||||
assertEquals(1, preds.length);
|
||||
assertArrayEquals(new float[]{(float) (i + 1)}, preds[0], 1e-2f);
|
||||
}
|
||||
} finally {
|
||||
if (trainMat != null)
|
||||
trainMat.dispose();
|
||||
else if (data0 != null) {
|
||||
data0.dispose();
|
||||
}
|
||||
Rabit.shutdown();
|
||||
}
|
||||
}
|
||||
|
||||
private String writeResourceIntoTempFile(String resource) {
|
||||
InputStream input = getClass().getResourceAsStream(resource);
|
||||
if (input == null) {
|
||||
throw new IllegalArgumentException("Resource " + resource + " does not exist.");
|
||||
}
|
||||
File tmp;
|
||||
try {
|
||||
tmp = File.createTempFile("junit", ".test");
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Unable to write to temp file.", e);
|
||||
}
|
||||
byte[] buff = new byte[1024];
|
||||
try (FileOutputStream output = new FileOutputStream(tmp)) {
|
||||
int n;
|
||||
while ((n = input.read(buff)) > 0) {
|
||||
output.write(buff, 0, n);
|
||||
}
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Unable to write to temp file.", e);
|
||||
}
|
||||
return tmp.getAbsolutePath();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSetAndGetGroup() throws XGBoostError {
|
||||
//create DMatrix from 10*5 dense matrix
|
||||
|
||||
Reference in New Issue
Block a user