[jvm-packages] Add BigDenseMatrix (#4383)
* Add BigDenseMatrix * ability to create DMatrix with bigger than Integer.MAX_VALUE size arrays * uses sun.misc.Unsafe * make DMatrix test work from a jar as well
This commit is contained in:
parent
57106a3459
commit
22209b7b95
@ -119,3 +119,7 @@ if __name__ == "__main__":
|
|||||||
cp(file, "xgboost4j-spark/src/test/resources")
|
cp(file, "xgboost4j-spark/src/test/resources")
|
||||||
for file in glob.glob("../demo/data/agaricus.*"):
|
for file in glob.glob("../demo/data/agaricus.*"):
|
||||||
cp(file, "xgboost4j-spark/src/test/resources")
|
cp(file, "xgboost4j-spark/src/test/resources")
|
||||||
|
|
||||||
|
maybe_makedirs("xgboost4j/src/test/resources")
|
||||||
|
for file in glob.glob("../demo/data/agaricus.*"):
|
||||||
|
cp(file, "xgboost4j/src/test/resources")
|
||||||
|
|||||||
@ -18,6 +18,7 @@ package ml.dmlc.xgboost4j.java;
|
|||||||
import java.util.Iterator;
|
import java.util.Iterator;
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.LabeledPoint;
|
import ml.dmlc.xgboost4j.LabeledPoint;
|
||||||
|
import ml.dmlc.xgboost4j.java.util.BigDenseMatrix;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* DMatrix for xgboost.
|
* DMatrix for xgboost.
|
||||||
@ -130,6 +131,16 @@ public class DMatrix {
|
|||||||
handle = out[0];
|
handle = out[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Builds a DMatrix from an off-heap {@link BigDenseMatrix}, using {@code 0.0f} as the
 * value that marks a missing entry.
 *
 * @param matrix instance of BigDenseMatrix
 * @throws XGBoostError native error
 */
public DMatrix(BigDenseMatrix matrix) throws XGBoostError {
  this(matrix, 0.0f);
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* create DMatrix from dense matrix
|
* create DMatrix from dense matrix
|
||||||
* @param data data values
|
* @param data data values
|
||||||
@ -143,6 +154,18 @@ public class DMatrix {
|
|||||||
handle = out[0];
|
handle = out[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Create a DMatrix from an off-heap dense matrix.
 *
 * <p>The matrix's native address and its row/column counts are handed to the native
 * library, which reports the new DMatrix handle through a one-element output array.
 *
 * @param matrix  instance of BigDenseMatrix
 * @param missing the specified value to represent the missing value
 * @throws XGBoostError native error
 * @throws NullPointerException if {@code matrix} is null
 */
public DMatrix(BigDenseMatrix matrix, float missing) throws XGBoostError {
  // Fail with a descriptive message instead of an anonymous NPE on the field access below.
  if (matrix == null) {
    throw new NullPointerException("matrix must not be null");
  }
  long[] out = new long[1];
  XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromMatRef(matrix.address, matrix.nrow,
      matrix.ncol, missing, out));
  handle = out[0];
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* used for DMatrix slice
|
* used for DMatrix slice
|
||||||
*/
|
*/
|
||||||
|
|||||||
@ -65,6 +65,9 @@ class XGBoostJNI {
|
|||||||
public final static native int XGDMatrixCreateFromMat(float[] data, int nrow, int ncol,
|
public final static native int XGDMatrixCreateFromMat(float[] data, int nrow, int ncol,
|
||||||
float missing, long[] out);
|
float missing, long[] out);
|
||||||
|
|
||||||
|
public final static native int XGDMatrixCreateFromMatRef(long dataRef, int nrow, int ncol,
|
||||||
|
float missing, long[] out);
|
||||||
|
|
||||||
public final static native int XGDMatrixSliceDMatrix(long handle, int[] idxset, long[] out);
|
public final static native int XGDMatrixSliceDMatrix(long handle, int[] idxset, long[] out);
|
||||||
|
|
||||||
public final static native int XGDMatrixFree(long handle);
|
public final static native int XGDMatrixFree(long handle);
|
||||||
|
|||||||
@ -0,0 +1,76 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
package ml.dmlc.xgboost4j.java.util;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Off-heap implementation of a Dense Matrix, matrix size is only limited by the
|
||||||
|
* amount of the available memory and the matrix dimension cannot exceed
|
||||||
|
* Integer.MAX_VALUE (this is consistent with XGBoost API restrictions on maximum
|
||||||
|
* length of a response).
|
||||||
|
*/
|
||||||
|
public final class BigDenseMatrix {
|
||||||
|
|
||||||
|
private static final int FLOAT_BYTE_SIZE = 4;
|
||||||
|
public static final long MAX_MATRIX_SIZE = Long.MAX_VALUE / FLOAT_BYTE_SIZE;
|
||||||
|
|
||||||
|
public final int nrow;
|
||||||
|
public final int ncol;
|
||||||
|
public final long address;
|
||||||
|
|
||||||
|
public static void setDirect(long valAddress, float val) {
|
||||||
|
UtilUnsafe.UNSAFE.putFloat(valAddress, val);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static float getDirect(long valAddress) {
|
||||||
|
return UtilUnsafe.UNSAFE.getFloat(valAddress);
|
||||||
|
}
|
||||||
|
|
||||||
|
public BigDenseMatrix(int nrow, int ncol) {
|
||||||
|
final long size = (long) nrow * ncol;
|
||||||
|
if (size > MAX_MATRIX_SIZE) {
|
||||||
|
throw new IllegalArgumentException("Matrix too large; matrix size cannot exceed " +
|
||||||
|
MAX_MATRIX_SIZE);
|
||||||
|
}
|
||||||
|
this.nrow = nrow;
|
||||||
|
this.ncol = ncol;
|
||||||
|
this.address = UtilUnsafe.UNSAFE.allocateMemory(size * FLOAT_BYTE_SIZE);
|
||||||
|
}
|
||||||
|
|
||||||
|
public final void set(long idx, float val) {
|
||||||
|
setDirect(address + idx * FLOAT_BYTE_SIZE, val);
|
||||||
|
}
|
||||||
|
|
||||||
|
public final void set(int i, int j, float val) {
|
||||||
|
set(index(i, j), val);
|
||||||
|
}
|
||||||
|
|
||||||
|
public final float get(long idx) {
|
||||||
|
return getDirect(address + idx * FLOAT_BYTE_SIZE);
|
||||||
|
}
|
||||||
|
|
||||||
|
public final float get(int i, int j) {
|
||||||
|
return get(index(i, j));
|
||||||
|
}
|
||||||
|
|
||||||
|
public final void dispose() {
|
||||||
|
UtilUnsafe.UNSAFE.freeMemory(address);
|
||||||
|
}
|
||||||
|
|
||||||
|
private long index(int i, int j) {
|
||||||
|
return (long) i * ncol + j;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@ -0,0 +1,46 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
package ml.dmlc.xgboost4j.java.util;
|
||||||
|
|
||||||
|
import java.lang.reflect.Field;
|
||||||
|
|
||||||
|
import sun.misc.Unsafe;
|
||||||
|
|
||||||
|
/**
 * Simple class to obtain access to the {@link Unsafe} object. Use responsibly :)
 */
public final class UtilUnsafe {

  /** Shared {@link Unsafe} instance, resolved once when this class is initialized. */
  static final Unsafe UNSAFE = getUnsafe();

  private UtilUnsafe() {
  } // dummy private constructor

  /**
   * Looks up the {@link Unsafe} singleton.
   *
   * @return the Unsafe instance
   * @throws RuntimeException if the reflective lookup of {@code theUnsafe} fails
   */
  private static Unsafe getUnsafe() {
    // A null class loader means this class was loaded by the bootstrap loader, in which
    // case the regular accessor can be tried directly.
    if (UtilUnsafe.class.getClassLoader() == null) {
      return Unsafe.getUnsafe();
    }
    // Otherwise read the private static singleton field reflectively.
    try {
      final Field fld = Unsafe.class.getDeclaredField("theUnsafe");
      fld.setAccessible(true);
      // The receiver is ignored for a static field, so null is the conventional argument.
      return (Unsafe) fld.get(null);
    } catch (Exception e) {
      throw new RuntimeException("Could not obtain access to sun.misc.Unsafe", e);
    }
  }

}
|
||||||
@ -247,6 +247,21 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
|
|||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
|
* Method: XGDMatrixCreateFromMatRef
|
||||||
|
 * Signature: (JIIF[J)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromMatRef
  (JNIEnv *jenv, jclass jcls, jlong jdataRef, jint jnrow, jint jncol, jfloat jmiss, jlongArray jout) {
  // jdataRef carries a raw native address as a Java long; view it as a read-only float buffer.
  float const *data = (float const *)jdataRef;
  DMatrixHandle handle;
  // Row/column counts are widened to the unsigned type the C API expects.
  jint status = (jint) XGDMatrixCreateFromMat(data, (bst_ulong)jnrow, (bst_ulong)jncol, jmiss, &handle);
  // Pass the resulting handle on via setHandle/jout, then return the native status code.
  setHandle(jenv, jout, handle);
  return status;
}
|
||||||
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
|
|||||||
@ -55,6 +55,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
|
|||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromMat
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromMat
|
||||||
(JNIEnv *, jclass, jfloatArray, jint, jint, jfloat, jlongArray);
|
(JNIEnv *, jclass, jfloatArray, jint, jint, jfloat, jlongArray);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
|
* Method: XGDMatrixCreateFromMatRef
|
||||||
|
* Signature: (JIIF[J)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromMatRef
|
||||||
|
(JNIEnv *jenv, jclass jcls, jlong jdataRef, jint jnrow, jint jncol, jfloat jmiss, jlongArray jout);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
* Method: XGDMatrixSliceDMatrix
|
* Method: XGDMatrixSliceDMatrix
|
||||||
|
|||||||
@ -15,13 +15,20 @@
|
|||||||
*/
|
*/
|
||||||
package ml.dmlc.xgboost4j.java;
|
package ml.dmlc.xgboost4j.java;
|
||||||
|
|
||||||
|
import java.io.*;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
import junit.framework.TestCase;
|
import junit.framework.TestCase;
|
||||||
|
import ml.dmlc.xgboost4j.java.util.BigDenseMatrix;
|
||||||
import ml.dmlc.xgboost4j.LabeledPoint;
|
import ml.dmlc.xgboost4j.LabeledPoint;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertArrayEquals;
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* test cases for DMatrix
|
* test cases for DMatrix
|
||||||
*
|
*
|
||||||
@ -53,7 +60,8 @@ public class DMatrixTest {
|
|||||||
@Test
|
@Test
|
||||||
public void testCreateFromFile() throws XGBoostError {
|
public void testCreateFromFile() throws XGBoostError {
|
||||||
//create DMatrix from file
|
//create DMatrix from file
|
||||||
DMatrix dmat = new DMatrix("../../demo/data/agaricus.txt.test");
|
String filePath = writeResourceIntoTempFile("/agaricus.txt.test");
|
||||||
|
DMatrix dmat = new DMatrix(filePath);
|
||||||
//get label
|
//get label
|
||||||
float[] labels = dmat.getLabel();
|
float[] labels = dmat.getLabel();
|
||||||
//check length
|
//check length
|
||||||
@ -224,6 +232,122 @@ public class DMatrixTest {
|
|||||||
TestCase.assertTrue(dmat0.getLabel().length == 10);
|
TestCase.assertTrue(dmat0.getLabel().length == 10);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCreateFromDenseMatrixRef() throws XGBoostError {
|
||||||
|
//create DMatrix from 10*5 dense matrix
|
||||||
|
final int nrow = 10;
|
||||||
|
final int ncol = 5;
|
||||||
|
|
||||||
|
DMatrix dmat0 = null;
|
||||||
|
BigDenseMatrix data0 = null;
|
||||||
|
try {
|
||||||
|
data0 = new BigDenseMatrix(nrow, ncol);
|
||||||
|
//put random nums
|
||||||
|
Random random = new Random();
|
||||||
|
for (int i = 0; i < nrow * ncol; i++) {
|
||||||
|
data0.set(i, random.nextFloat());
|
||||||
|
}
|
||||||
|
|
||||||
|
//create label
|
||||||
|
float[] label0 = new float[nrow];
|
||||||
|
for (int i = 0; i < nrow; i++) {
|
||||||
|
label0[i] = random.nextFloat();
|
||||||
|
}
|
||||||
|
|
||||||
|
dmat0 = new DMatrix(data0);
|
||||||
|
dmat0.setLabel(label0);
|
||||||
|
|
||||||
|
//check
|
||||||
|
TestCase.assertTrue(dmat0.rowNum() == 10);
|
||||||
|
TestCase.assertTrue(dmat0.getLabel().length == 10);
|
||||||
|
} finally {
|
||||||
|
if (dmat0 != null) {
|
||||||
|
dmat0.dispose();
|
||||||
|
} else if (data0 != null){
|
||||||
|
data0.dispose();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testTrainWithDenseMatrixRef() throws XGBoostError {
|
||||||
|
Map<String, String> rabitEnv = new HashMap<>();
|
||||||
|
rabitEnv.put("DMLC_TASK_ID", "0");
|
||||||
|
Rabit.init(rabitEnv);
|
||||||
|
DMatrix trainMat = null;
|
||||||
|
BigDenseMatrix data0 = null;
|
||||||
|
try {
|
||||||
|
// trivial dataset with 3 rows and 2 columns
|
||||||
|
// (4,5) -> 1
|
||||||
|
// (3,1) -> 2
|
||||||
|
// (2,3) -> 3
|
||||||
|
float[][] data = new float[][]{
|
||||||
|
new float[]{4f, 5f},
|
||||||
|
new float[]{3f, 1f},
|
||||||
|
new float[]{2f, 3f}
|
||||||
|
};
|
||||||
|
data0 = new BigDenseMatrix(3, 2);
|
||||||
|
for (int i = 0; i < data0.nrow; i++)
|
||||||
|
for (int j = 0; j < data0.ncol; j++)
|
||||||
|
data0.set(i, j, data[i][j]);
|
||||||
|
|
||||||
|
trainMat = new DMatrix(data0);
|
||||||
|
trainMat.setLabel(new float[]{1f, 2f, 3f});
|
||||||
|
|
||||||
|
HashMap<String, Object> params = new HashMap<>();
|
||||||
|
params.put("eta", 1);
|
||||||
|
params.put("max_depth", 5);
|
||||||
|
params.put("silent", 1);
|
||||||
|
params.put("objective", "reg:linear");
|
||||||
|
params.put("seed", 123);
|
||||||
|
|
||||||
|
HashMap<String, DMatrix> watches = new HashMap<>();
|
||||||
|
watches.put("train", trainMat);
|
||||||
|
|
||||||
|
Booster booster = XGBoost.train(trainMat, params, 10, watches, null, null);
|
||||||
|
|
||||||
|
// check overfitting
|
||||||
|
// (4,5) -> 1
|
||||||
|
// (3,1) -> 2
|
||||||
|
// (2,3) -> 3
|
||||||
|
for (int i = 0; i < 3; i++) {
|
||||||
|
float[][] preds = booster.predict(new DMatrix(data[i], 1, 2));
|
||||||
|
assertEquals(1, preds.length);
|
||||||
|
assertArrayEquals(new float[]{(float) (i + 1)}, preds[0], 1e-2f);
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
if (trainMat != null)
|
||||||
|
trainMat.dispose();
|
||||||
|
else if (data0 != null) {
|
||||||
|
data0.dispose();
|
||||||
|
}
|
||||||
|
Rabit.shutdown();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private String writeResourceIntoTempFile(String resource) {
|
||||||
|
InputStream input = getClass().getResourceAsStream(resource);
|
||||||
|
if (input == null) {
|
||||||
|
throw new IllegalArgumentException("Resource " + resource + " does not exist.");
|
||||||
|
}
|
||||||
|
File tmp;
|
||||||
|
try {
|
||||||
|
tmp = File.createTempFile("junit", ".test");
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new RuntimeException("Unable to write to temp file.", e);
|
||||||
|
}
|
||||||
|
byte[] buff = new byte[1024];
|
||||||
|
try (FileOutputStream output = new FileOutputStream(tmp)) {
|
||||||
|
int n;
|
||||||
|
while ((n = input.read(buff)) > 0) {
|
||||||
|
output.write(buff, 0, n);
|
||||||
|
}
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new RuntimeException("Unable to write to temp file.", e);
|
||||||
|
}
|
||||||
|
return tmp.getAbsolutePath();
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSetAndGetGroup() throws XGBoostError {
|
public void testSetAndGetGroup() throws XGBoostError {
|
||||||
//create DMatrix from 10*5 dense matrix
|
//create DMatrix from 10*5 dense matrix
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user