[jvm-packages] Add BigDenseMatrix (#4383)
* Add BigDenseMatrix * ability to create DMatrix with bigger than Integer.MAX_VALUE size arrays * uses sun.misc.Unsafe * make DMatrix test work from a jar as well
This commit is contained in:
parent
57106a3459
commit
22209b7b95
@ -119,3 +119,7 @@ if __name__ == "__main__":
|
||||
cp(file, "xgboost4j-spark/src/test/resources")
|
||||
for file in glob.glob("../demo/data/agaricus.*"):
|
||||
cp(file, "xgboost4j-spark/src/test/resources")
|
||||
|
||||
maybe_makedirs("xgboost4j/src/test/resources")
|
||||
for file in glob.glob("../demo/data/agaricus.*"):
|
||||
cp(file, "xgboost4j/src/test/resources")
|
||||
|
||||
@ -18,6 +18,7 @@ package ml.dmlc.xgboost4j.java;
|
||||
import java.util.Iterator;
|
||||
|
||||
import ml.dmlc.xgboost4j.LabeledPoint;
|
||||
import ml.dmlc.xgboost4j.java.util.BigDenseMatrix;
|
||||
|
||||
/**
|
||||
* DMatrix for xgboost.
|
||||
@ -130,6 +131,16 @@ public class DMatrix {
|
||||
handle = out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* create DMatrix from a BigDenseMatrix
|
||||
*
|
||||
* @param matrix instance of BigDenseMatrix
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public DMatrix(BigDenseMatrix matrix) throws XGBoostError {
|
||||
this(matrix, 0.0f);
|
||||
}
|
||||
|
||||
/**
|
||||
* create DMatrix from dense matrix
|
||||
* @param data data values
|
||||
@ -143,6 +154,18 @@ public class DMatrix {
|
||||
handle = out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* create DMatrix from dense matrix
|
||||
* @param matrix instance of BigDenseMatrix
|
||||
* @param missing the specified value to represent the missing value
|
||||
*/
|
||||
public DMatrix(BigDenseMatrix matrix, float missing) throws XGBoostError {
|
||||
long[] out = new long[1];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromMatRef(matrix.address, matrix.nrow,
|
||||
matrix.ncol, missing, out));
|
||||
handle = out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* used for DMatrix slice
|
||||
*/
|
||||
|
||||
@ -65,6 +65,9 @@ class XGBoostJNI {
|
||||
public final static native int XGDMatrixCreateFromMat(float[] data, int nrow, int ncol,
|
||||
float missing, long[] out);
|
||||
|
||||
public final static native int XGDMatrixCreateFromMatRef(long dataRef, int nrow, int ncol,
|
||||
float missing, long[] out);
|
||||
|
||||
public final static native int XGDMatrixSliceDMatrix(long handle, int[] idxset, long[] out);
|
||||
|
||||
public final static native int XGDMatrixFree(long handle);
|
||||
|
||||
@ -0,0 +1,76 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.java.util;
|
||||
|
||||
/**
|
||||
* Off-heap implementation of a Dense Matrix, matrix size is only limited by the
|
||||
* amount of the available memory and the matrix dimension cannot exceed
|
||||
* Integer.MAX_VALUE (this is consistent with XGBoost API restrictions on maximum
|
||||
* length of a response).
|
||||
*/
|
||||
public final class BigDenseMatrix {
|
||||
|
||||
private static final int FLOAT_BYTE_SIZE = 4;
|
||||
public static final long MAX_MATRIX_SIZE = Long.MAX_VALUE / FLOAT_BYTE_SIZE;
|
||||
|
||||
public final int nrow;
|
||||
public final int ncol;
|
||||
public final long address;
|
||||
|
||||
public static void setDirect(long valAddress, float val) {
|
||||
UtilUnsafe.UNSAFE.putFloat(valAddress, val);
|
||||
}
|
||||
|
||||
public static float getDirect(long valAddress) {
|
||||
return UtilUnsafe.UNSAFE.getFloat(valAddress);
|
||||
}
|
||||
|
||||
public BigDenseMatrix(int nrow, int ncol) {
|
||||
final long size = (long) nrow * ncol;
|
||||
if (size > MAX_MATRIX_SIZE) {
|
||||
throw new IllegalArgumentException("Matrix too large; matrix size cannot exceed " +
|
||||
MAX_MATRIX_SIZE);
|
||||
}
|
||||
this.nrow = nrow;
|
||||
this.ncol = ncol;
|
||||
this.address = UtilUnsafe.UNSAFE.allocateMemory(size * FLOAT_BYTE_SIZE);
|
||||
}
|
||||
|
||||
public final void set(long idx, float val) {
|
||||
setDirect(address + idx * FLOAT_BYTE_SIZE, val);
|
||||
}
|
||||
|
||||
public final void set(int i, int j, float val) {
|
||||
set(index(i, j), val);
|
||||
}
|
||||
|
||||
public final float get(long idx) {
|
||||
return getDirect(address + idx * FLOAT_BYTE_SIZE);
|
||||
}
|
||||
|
||||
public final float get(int i, int j) {
|
||||
return get(index(i, j));
|
||||
}
|
||||
|
||||
public final void dispose() {
|
||||
UtilUnsafe.UNSAFE.freeMemory(address);
|
||||
}
|
||||
|
||||
private long index(int i, int j) {
|
||||
return (long) i * ncol + j;
|
||||
}
|
||||
|
||||
}
|
||||
@ -0,0 +1,46 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.java.util;
|
||||
|
||||
import java.lang.reflect.Field;
|
||||
|
||||
import sun.misc.Unsafe;
|
||||
|
||||
/**
|
||||
* Simple class to obtain access to the {@link Unsafe} object. Use responsibly :)
|
||||
*/
|
||||
public final class UtilUnsafe {
|
||||
|
||||
static Unsafe UNSAFE = getUnsafe();
|
||||
|
||||
private UtilUnsafe() {
|
||||
} // dummy private constructor
|
||||
|
||||
private static Unsafe getUnsafe() {
|
||||
// Not on bootclasspath
|
||||
if (UtilUnsafe.class.getClassLoader() == null) {
|
||||
return Unsafe.getUnsafe();
|
||||
}
|
||||
try {
|
||||
final Field fld = Unsafe.class.getDeclaredField("theUnsafe");
|
||||
fld.setAccessible(true);
|
||||
return (Unsafe) fld.get(UtilUnsafe.class);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Could not obtain access to sun.misc.Unsafe", e);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@ -247,6 +247,21 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixCreateFromMatRef
|
||||
* Signature: (JIIF)J
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromMatRef
|
||||
(JNIEnv *jenv, jclass jcls, jlong jdataRef, jint jnrow, jint jncol, jfloat jmiss, jlongArray jout) {
|
||||
DMatrixHandle result;
|
||||
bst_ulong nrow = (bst_ulong)jnrow;
|
||||
bst_ulong ncol = (bst_ulong)jncol;
|
||||
jint ret = (jint) XGDMatrixCreateFromMat((float const *)jdataRef, nrow, ncol, jmiss, &result);
|
||||
setHandle(jenv, jout, result);
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
|
||||
@ -55,6 +55,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromMat
|
||||
(JNIEnv *, jclass, jfloatArray, jint, jint, jfloat, jlongArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixCreateFromMatRef
|
||||
* Signature: (JIIF[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromMatRef
|
||||
(JNIEnv *jenv, jclass jcls, jlong jdataRef, jint jnrow, jint jncol, jfloat jmiss, jlongArray jout);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixSliceDMatrix
|
||||
|
||||
@ -15,13 +15,20 @@
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.*;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
|
||||
import junit.framework.TestCase;
|
||||
import ml.dmlc.xgboost4j.java.util.BigDenseMatrix;
|
||||
import ml.dmlc.xgboost4j.LabeledPoint;
|
||||
import org.junit.Test;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
/**
|
||||
* test cases for DMatrix
|
||||
*
|
||||
@ -53,7 +60,8 @@ public class DMatrixTest {
|
||||
@Test
|
||||
public void testCreateFromFile() throws XGBoostError {
|
||||
//create DMatrix from file
|
||||
DMatrix dmat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
String filePath = writeResourceIntoTempFile("/agaricus.txt.test");
|
||||
DMatrix dmat = new DMatrix(filePath);
|
||||
//get label
|
||||
float[] labels = dmat.getLabel();
|
||||
//check length
|
||||
@ -224,6 +232,122 @@ public class DMatrixTest {
|
||||
TestCase.assertTrue(dmat0.getLabel().length == 10);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCreateFromDenseMatrixRef() throws XGBoostError {
|
||||
//create DMatrix from 10*5 dense matrix
|
||||
final int nrow = 10;
|
||||
final int ncol = 5;
|
||||
|
||||
DMatrix dmat0 = null;
|
||||
BigDenseMatrix data0 = null;
|
||||
try {
|
||||
data0 = new BigDenseMatrix(nrow, ncol);
|
||||
//put random nums
|
||||
Random random = new Random();
|
||||
for (int i = 0; i < nrow * ncol; i++) {
|
||||
data0.set(i, random.nextFloat());
|
||||
}
|
||||
|
||||
//create label
|
||||
float[] label0 = new float[nrow];
|
||||
for (int i = 0; i < nrow; i++) {
|
||||
label0[i] = random.nextFloat();
|
||||
}
|
||||
|
||||
dmat0 = new DMatrix(data0);
|
||||
dmat0.setLabel(label0);
|
||||
|
||||
//check
|
||||
TestCase.assertTrue(dmat0.rowNum() == 10);
|
||||
TestCase.assertTrue(dmat0.getLabel().length == 10);
|
||||
} finally {
|
||||
if (dmat0 != null) {
|
||||
dmat0.dispose();
|
||||
} else if (data0 != null){
|
||||
data0.dispose();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTrainWithDenseMatrixRef() throws XGBoostError {
|
||||
Map<String, String> rabitEnv = new HashMap<>();
|
||||
rabitEnv.put("DMLC_TASK_ID", "0");
|
||||
Rabit.init(rabitEnv);
|
||||
DMatrix trainMat = null;
|
||||
BigDenseMatrix data0 = null;
|
||||
try {
|
||||
// trivial dataset with 3 rows and 2 columns
|
||||
// (4,5) -> 1
|
||||
// (3,1) -> 2
|
||||
// (2,3) -> 3
|
||||
float[][] data = new float[][]{
|
||||
new float[]{4f, 5f},
|
||||
new float[]{3f, 1f},
|
||||
new float[]{2f, 3f}
|
||||
};
|
||||
data0 = new BigDenseMatrix(3, 2);
|
||||
for (int i = 0; i < data0.nrow; i++)
|
||||
for (int j = 0; j < data0.ncol; j++)
|
||||
data0.set(i, j, data[i][j]);
|
||||
|
||||
trainMat = new DMatrix(data0);
|
||||
trainMat.setLabel(new float[]{1f, 2f, 3f});
|
||||
|
||||
HashMap<String, Object> params = new HashMap<>();
|
||||
params.put("eta", 1);
|
||||
params.put("max_depth", 5);
|
||||
params.put("silent", 1);
|
||||
params.put("objective", "reg:linear");
|
||||
params.put("seed", 123);
|
||||
|
||||
HashMap<String, DMatrix> watches = new HashMap<>();
|
||||
watches.put("train", trainMat);
|
||||
|
||||
Booster booster = XGBoost.train(trainMat, params, 10, watches, null, null);
|
||||
|
||||
// check overfitting
|
||||
// (4,5) -> 1
|
||||
// (3,1) -> 2
|
||||
// (2,3) -> 3
|
||||
for (int i = 0; i < 3; i++) {
|
||||
float[][] preds = booster.predict(new DMatrix(data[i], 1, 2));
|
||||
assertEquals(1, preds.length);
|
||||
assertArrayEquals(new float[]{(float) (i + 1)}, preds[0], 1e-2f);
|
||||
}
|
||||
} finally {
|
||||
if (trainMat != null)
|
||||
trainMat.dispose();
|
||||
else if (data0 != null) {
|
||||
data0.dispose();
|
||||
}
|
||||
Rabit.shutdown();
|
||||
}
|
||||
}
|
||||
|
||||
private String writeResourceIntoTempFile(String resource) {
|
||||
InputStream input = getClass().getResourceAsStream(resource);
|
||||
if (input == null) {
|
||||
throw new IllegalArgumentException("Resource " + resource + " does not exist.");
|
||||
}
|
||||
File tmp;
|
||||
try {
|
||||
tmp = File.createTempFile("junit", ".test");
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Unable to write to temp file.", e);
|
||||
}
|
||||
byte[] buff = new byte[1024];
|
||||
try (FileOutputStream output = new FileOutputStream(tmp)) {
|
||||
int n;
|
||||
while ((n = input.read(buff)) > 0) {
|
||||
output.write(buff, 0, n);
|
||||
}
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Unable to write to temp file.", e);
|
||||
}
|
||||
return tmp.getAbsolutePath();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSetAndGetGroup() throws XGBoostError {
|
||||
//create DMatrix from 10*5 dense matrix
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user