[jvm-packages] Add BigDenseMatrix (#4383)

* Add BigDenseMatrix

* ability to create DMatrix with bigger than Integer.MAX_VALUE size arrays
* uses sun.misc.Unsafe

* make DMatrix test work from a jar as well
This commit is contained in:
Honza Sterba 2019-09-19 05:46:14 +02:00 committed by Philip Hyunsu Cho
parent 57106a3459
commit 22209b7b95
8 changed files with 300 additions and 1 deletions

View File

@ -119,3 +119,7 @@ if __name__ == "__main__":
cp(file, "xgboost4j-spark/src/test/resources")
# Copy the agaricus demo data files into the xgboost4j-spark test resources.
for file in glob.glob("../demo/data/agaricus.*"):
cp(file, "xgboost4j-spark/src/test/resources")
# Copy the same demo data into the xgboost4j test resources so the tests can
# load it from the classpath (presumably maybe_makedirs creates the directory
# when it is missing -- confirm against its definition).
maybe_makedirs("xgboost4j/src/test/resources")
for file in glob.glob("../demo/data/agaricus.*"):
cp(file, "xgboost4j/src/test/resources")

View File

@ -18,6 +18,7 @@ package ml.dmlc.xgboost4j.java;
import java.util.Iterator;
import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.util.BigDenseMatrix;
/**
* DMatrix for xgboost.
@ -130,6 +131,16 @@ public class DMatrix {
handle = out[0];
}
/**
 * Creates a DMatrix from an off-heap {@link BigDenseMatrix}, treating the value
 * {@code 0.0f} as the marker for missing entries.
 *
 * @param matrix instance of BigDenseMatrix holding the dense values
 * @throws XGBoostError native error
 */
public DMatrix(BigDenseMatrix matrix) throws XGBoostError {
  this(matrix, 0f);
}
/**
* create DMatrix from dense matrix
* @param data data values
@ -143,6 +154,18 @@ public class DMatrix {
handle = out[0];
}
/**
* create DMatrix from dense matrix
* @param matrix instance of BigDenseMatrix
* @param missing the specified value to represent the missing value
*/
public DMatrix(BigDenseMatrix matrix, float missing) throws XGBoostError {
long[] out = new long[1];
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromMatRef(matrix.address, matrix.nrow,
matrix.ncol, missing, out));
handle = out[0];
}
/**
* used for DMatrix slice
*/

View File

@ -65,6 +65,9 @@ class XGBoostJNI {
// Creates a DMatrix from an on-heap float array; the new native handle is
// written to out[0]. NOTE(review): data is presumably row-major nrow x ncol --
// confirm against the native implementation.
public final static native int XGDMatrixCreateFromMat(float[] data, int nrow, int ncol,
float missing, long[] out);
// Creates a DMatrix from floats stored at the raw native address dataRef
// (e.g. BigDenseMatrix.address); the new native handle is written to out[0].
// The native side casts dataRef straight to a float pointer, so the address
// must stay valid for the duration of the call.
public final static native int XGDMatrixCreateFromMatRef(long dataRef, int nrow, int ncol,
float missing, long[] out);
// Slices the DMatrix behind handle by the given index set; the result handle
// is presumably written to out[0] like the creation calls above -- confirm.
public final static native int XGDMatrixSliceDMatrix(long handle, int[] idxset, long[] out);
// Releases the native DMatrix behind handle.
public final static native int XGDMatrixFree(long handle);

View File

@ -0,0 +1,76 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java.util;
/**
 * Off-heap implementation of a dense matrix of floats stored in row-major order
 * (element {@code (i, j)} lives at flat index {@code i * ncol + j}).
 *
 * <p>The matrix size is only limited by the amount of available memory; each
 * dimension cannot exceed Integer.MAX_VALUE (this is consistent with XGBoost API
 * restrictions on maximum length of a response).
 *
 * <p>The backing memory is allocated outside the Java heap and is NOT zeroed on
 * allocation, nor is it released by the garbage collector: callers must fill every
 * cell they intend to read and must call {@link #dispose()} when done.
 */
public final class BigDenseMatrix {

  private static final int FLOAT_BYTE_SIZE = 4;

  /** Largest number of cells whose byte size still fits in a {@code long}. */
  public static final long MAX_MATRIX_SIZE = Long.MAX_VALUE / FLOAT_BYTE_SIZE;

  /** Number of rows. */
  public final int nrow;
  /** Number of columns. */
  public final int ncol;
  /** Native address of the first cell; invalid once {@link #dispose()} was called. */
  public final long address;

  /** Total number of cells ({@code nrow * ncol}), cached for bounds checks. */
  private final long size;
  /** Set once the backing memory has been released. */
  private boolean disposed;

  /**
   * Writes a float to an arbitrary native address. No validation is performed.
   *
   * @param valAddress native address to write to
   * @param val value to store
   */
  public static void setDirect(long valAddress, float val) {
    UtilUnsafe.UNSAFE.putFloat(valAddress, val);
  }

  /**
   * Reads a float from an arbitrary native address. No validation is performed.
   *
   * @param valAddress native address to read from
   * @return the float stored at that address
   */
  public static float getDirect(long valAddress) {
    return UtilUnsafe.UNSAFE.getFloat(valAddress);
  }

  /**
   * Allocates an uninitialized {@code nrow x ncol} matrix off-heap.
   *
   * @param nrow number of rows, must not be negative
   * @param ncol number of columns, must not be negative
   * @throws IllegalArgumentException if a dimension is negative or the matrix
   *         would exceed {@link #MAX_MATRIX_SIZE} cells
   */
  public BigDenseMatrix(int nrow, int ncol) {
    // Two negative dimensions would otherwise multiply to a valid positive size.
    if (nrow < 0 || ncol < 0) {
      throw new IllegalArgumentException("Matrix dimensions cannot be negative; got " +
              nrow + " x " + ncol);
    }
    final long size = (long) nrow * ncol;
    if (size > MAX_MATRIX_SIZE) {
      throw new IllegalArgumentException("Matrix too large; matrix size cannot exceed " +
              MAX_MATRIX_SIZE);
    }
    this.nrow = nrow;
    this.ncol = ncol;
    this.size = size;
    this.address = UtilUnsafe.UNSAFE.allocateMemory(size * FLOAT_BYTE_SIZE);
  }

  /**
   * Stores a value at a flat (row-major) index.
   *
   * @param idx flat index in {@code [0, nrow * ncol)}
   * @param val value to store
   * @throws IndexOutOfBoundsException if {@code idx} is outside the matrix
   * @throws IllegalStateException if the matrix was already disposed
   */
  public void set(long idx, float val) {
    checkAccess(idx);
    setDirect(address + idx * FLOAT_BYTE_SIZE, val);
  }

  /**
   * Stores a value at row {@code i}, column {@code j}.
   *
   * @throws IndexOutOfBoundsException if {@code i} or {@code j} is outside the matrix
   * @throws IllegalStateException if the matrix was already disposed
   */
  public void set(int i, int j, float val) {
    set(index(i, j), val);
  }

  /**
   * Reads the value at a flat (row-major) index.
   *
   * @param idx flat index in {@code [0, nrow * ncol)}
   * @return the stored value
   * @throws IndexOutOfBoundsException if {@code idx} is outside the matrix
   * @throws IllegalStateException if the matrix was already disposed
   */
  public float get(long idx) {
    checkAccess(idx);
    return getDirect(address + idx * FLOAT_BYTE_SIZE);
  }

  /**
   * Reads the value at row {@code i}, column {@code j}.
   *
   * @throws IndexOutOfBoundsException if {@code i} or {@code j} is outside the matrix
   * @throws IllegalStateException if the matrix was already disposed
   */
  public float get(int i, int j) {
    return get(index(i, j));
  }

  /**
   * Releases the off-heap memory. Calling this more than once is a no-op, so the
   * same address is never freed twice.
   */
  public void dispose() {
    if (disposed) {
      return;
    }
    disposed = true;
    UtilUnsafe.UNSAFE.freeMemory(address);
  }

  /** Rejects accesses after disposal and accesses outside the allocated region. */
  private void checkAccess(long idx) {
    if (disposed) {
      throw new IllegalStateException("Matrix has already been disposed");
    }
    if (idx < 0 || idx >= size) {
      throw new IndexOutOfBoundsException("Index " + idx +
              " is out of bounds for matrix of size " + size);
    }
  }

  /** Converts (row, column) to a flat row-major index, validating both coordinates. */
  private long index(int i, int j) {
    // Checked separately: an out-of-range column would otherwise alias another row.
    if (i < 0 || i >= nrow || j < 0 || j >= ncol) {
      throw new IndexOutOfBoundsException("Position (" + i + ", " + j +
              ") is out of bounds for a " + nrow + " x " + ncol + " matrix");
    }
    return (long) i * ncol + j;
  }
}

View File

@ -0,0 +1,46 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java.util;
import java.lang.reflect.Field;
import sun.misc.Unsafe;
/**
 * Simple class to obtain access to the {@link Unsafe} object. Use responsibly :)
 *
 * <p>The instance is resolved once, when this class is initialized; a failure to
 * obtain it surfaces as a {@link RuntimeException} during class initialization.
 */
public final class UtilUnsafe {

  /** The shared Unsafe instance; final so it cannot be swapped out after initialization. */
  static final Unsafe UNSAFE = getUnsafe();

  private UtilUnsafe() {
  } // utility class, never instantiated

  /**
   * Looks up the singleton {@link Unsafe} instance.
   *
   * @return the Unsafe instance
   * @throws RuntimeException if the instance cannot be obtained reflectively
   */
  private static Unsafe getUnsafe() {
    // A null class loader means this class was loaded by the bootstrap loader
    // (i.e. it IS on the boot class path), in which case the direct accessor is tried.
    if (UtilUnsafe.class.getClassLoader() == null) {
      return Unsafe.getUnsafe();
    }
    try {
      final Field fld = Unsafe.class.getDeclaredField("theUnsafe");
      fld.setAccessible(true);
      // theUnsafe is a static field, so no receiver object is needed.
      return (Unsafe) fld.get(null);
    } catch (ReflectiveOperationException | RuntimeException e) {
      // Runtime failures (e.g. a security manager denying setAccessible) are wrapped
      // as well so callers always see the same message with the cause attached.
      throw new RuntimeException("Could not obtain access to sun.misc.Unsafe", e);
    }
  }
}

View File

@ -247,6 +247,21 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
return ret;
}
/*
 * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method:    XGDMatrixCreateFromMatRef
 * Signature: (JIIF[J)I
 *
 * Creates a DMatrix from floats located at the raw native address jdataRef.
 * jdataRef is reinterpreted as a pointer to jnrow * jncol floats and passed to
 * XGDMatrixCreateFromMat together with the missing-value marker jmiss. The
 * resulting handle is stored into jout via setHandle and the status code of the
 * native call is returned.
 */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromMatRef
  (JNIEnv *jenv, jclass jcls, jlong jdataRef, jint jnrow, jint jncol, jfloat jmiss, jlongArray jout) {
  // Start from a null handle so a failed call never publishes an
  // uninitialized pointer value to the Java side.
  DMatrixHandle result = NULL;
  bst_ulong nrow = (bst_ulong)jnrow;
  bst_ulong ncol = (bst_ulong)jncol;
  float const *data = (float const *)jdataRef;
  jint ret = (jint) XGDMatrixCreateFromMat(data, nrow, ncol, jmiss, &result);
  setHandle(jenv, jout, result);
  return ret;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI

View File

@ -55,6 +55,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromMat
(JNIEnv *, jclass, jfloatArray, jint, jint, jfloat, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixCreateFromMatRef
* Signature: (JIIF[J)I
*
* Parameters after the JNI env/class: jdataRef (native address of the float
* data, passed as a long), jnrow, jncol (dimensions), jmiss (missing-value
* marker), jout (long array that receives the created handle).
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromMatRef
(JNIEnv *jenv, jclass jcls, jlong jdataRef, jint jnrow, jint jncol, jfloat jmiss, jlongArray jout);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixSliceDMatrix

View File

@ -15,13 +15,20 @@
*/
package ml.dmlc.xgboost4j.java;
import java.io.*;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import junit.framework.TestCase;
import ml.dmlc.xgboost4j.java.util.BigDenseMatrix;
import ml.dmlc.xgboost4j.LabeledPoint;
import org.junit.Test;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
/**
* test cases for DMatrix
*
@ -53,7 +60,8 @@ public class DMatrixTest {
@Test
public void testCreateFromFile() throws XGBoostError {
//create DMatrix from file
DMatrix dmat = new DMatrix("../../demo/data/agaricus.txt.test");
String filePath = writeResourceIntoTempFile("/agaricus.txt.test");
DMatrix dmat = new DMatrix(filePath);
//get label
float[] labels = dmat.getLabel();
//check length
@ -224,6 +232,122 @@ public class DMatrixTest {
TestCase.assertTrue(dmat0.getLabel().length == 10);
}
/**
 * Creates a DMatrix from a 10*5 off-heap {@link BigDenseMatrix} filled with random
 * values, attaches labels and checks the reported row count and label length.
 */
@Test
public void testCreateFromDenseMatrixRef() throws XGBoostError {
  //create DMatrix from 10*5 dense matrix
  final int nrow = 10;
  final int ncol = 5;
  DMatrix dmat0 = null;
  BigDenseMatrix data0 = null;
  try {
    data0 = new BigDenseMatrix(nrow, ncol);
    //put random nums
    Random random = new Random();
    for (int i = 0; i < nrow * ncol; i++) {
      data0.set(i, random.nextFloat());
    }
    //create label
    float[] label0 = new float[nrow];
    for (int i = 0; i < nrow; i++) {
      label0[i] = random.nextFloat();
    }
    dmat0 = new DMatrix(data0);
    dmat0.setLabel(label0);
    //check
    TestCase.assertTrue(dmat0.rowNum() == nrow);
    TestCase.assertTrue(dmat0.getLabel().length == nrow);
  } finally {
    // The DMatrix and the off-heap buffer are separate native allocations (the
    // DMatrix is only given the buffer's address), so both must be released;
    // releasing only one of them leaks the other.
    try {
      if (dmat0 != null) {
        dmat0.dispose();
      }
    } finally {
      if (data0 != null) {
        data0.dispose();
      }
    }
  }
}
/**
 * Trains a booster on a trivial 3x2 dataset supplied through an off-heap
 * {@link BigDenseMatrix} and checks that the model reproduces the training labels.
 */
@Test
public void testTrainWithDenseMatrixRef() throws XGBoostError {
  Map<String, String> rabitEnv = new HashMap<>();
  rabitEnv.put("DMLC_TASK_ID", "0");
  Rabit.init(rabitEnv);
  DMatrix trainMat = null;
  BigDenseMatrix data0 = null;
  try {
    // trivial dataset with 3 rows and 2 columns
    // (4,5) -> 1
    // (3,1) -> 2
    // (2,3) -> 3
    float[][] data = new float[][]{
      new float[]{4f, 5f},
      new float[]{3f, 1f},
      new float[]{2f, 3f}
    };
    data0 = new BigDenseMatrix(3, 2);
    for (int i = 0; i < data0.nrow; i++) {
      for (int j = 0; j < data0.ncol; j++) {
        data0.set(i, j, data[i][j]);
      }
    }
    trainMat = new DMatrix(data0);
    trainMat.setLabel(new float[]{1f, 2f, 3f});
    HashMap<String, Object> params = new HashMap<>();
    params.put("eta", 1);
    params.put("max_depth", 5);
    params.put("silent", 1);
    params.put("objective", "reg:linear");
    params.put("seed", 123);
    HashMap<String, DMatrix> watches = new HashMap<>();
    watches.put("train", trainMat);
    Booster booster = XGBoost.train(trainMat, params, 10, watches, null, null);
    // check overfitting
    // (4,5) -> 1
    // (3,1) -> 2
    // (2,3) -> 3
    for (int i = 0; i < 3; i++) {
      // Each single-row matrix owns a native handle; release it once predicted.
      DMatrix row = new DMatrix(data[i], 1, 2);
      try {
        float[][] preds = booster.predict(row);
        assertEquals(1, preds.length);
        assertArrayEquals(new float[]{(float) (i + 1)}, preds[0], 1e-2f);
      } finally {
        row.dispose();
      }
    }
  } finally {
    // The DMatrix and the off-heap buffer are separate native allocations (the
    // DMatrix is only given the buffer's address), so both must be released.
    try {
      if (trainMat != null) {
        trainMat.dispose();
      }
    } finally {
      try {
        if (data0 != null) {
          data0.dispose();
        }
      } finally {
        Rabit.shutdown();
      }
    }
  }
}
/**
 * Copies a classpath resource into a freshly created temporary file so it can be
 * handed to APIs that need a real file path (this also works when the resource is
 * packaged inside a jar).
 *
 * @param resource resource name as accepted by {@link Class#getResourceAsStream(String)}
 * @return absolute path of the temporary file holding the resource content
 * @throws IllegalArgumentException if the resource cannot be found
 * @throws RuntimeException if the temporary file cannot be created or written
 */
private String writeResourceIntoTempFile(String resource) {
  final File tmp;
  // try-with-resources closes the resource stream on every path; a null resource
  // is permitted by the language and is simply not closed.
  try (InputStream input = getClass().getResourceAsStream(resource)) {
    if (input == null) {
      throw new IllegalArgumentException("Resource " + resource + " does not exist.");
    }
    tmp = File.createTempFile("junit", ".test");
    // Do not let test runs accumulate files in the temp directory.
    tmp.deleteOnExit();
    try (FileOutputStream output = new FileOutputStream(tmp)) {
      final byte[] buff = new byte[1024];
      int n;
      while ((n = input.read(buff)) != -1) {
        output.write(buff, 0, n);
      }
    }
  } catch (IOException e) {
    throw new RuntimeException("Unable to write to temp file.", e);
  }
  return tmp.getAbsolutePath();
}
@Test
public void testSetAndGetGroup() throws XGBoostError {
//create DMatrix from 10*5 dense matrix