diff --git a/jvm-packages/create_jni.py b/jvm-packages/create_jni.py index 82e667889..4d627cb7c 100755 --- a/jvm-packages/create_jni.py +++ b/jvm-packages/create_jni.py @@ -119,3 +119,7 @@ if __name__ == "__main__": cp(file, "xgboost4j-spark/src/test/resources") for file in glob.glob("../demo/data/agaricus.*"): cp(file, "xgboost4j-spark/src/test/resources") + + maybe_makedirs("xgboost4j/src/test/resources") + for file in glob.glob("../demo/data/agaricus.*"): + cp(file, "xgboost4j/src/test/resources") diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java index 69647354c..47a280793 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java @@ -18,6 +18,7 @@ package ml.dmlc.xgboost4j.java; import java.util.Iterator; import ml.dmlc.xgboost4j.LabeledPoint; +import ml.dmlc.xgboost4j.java.util.BigDenseMatrix; /** * DMatrix for xgboost. @@ -130,6 +131,16 @@ public class DMatrix { handle = out[0]; } + /** + * create DMatrix from a BigDenseMatrix + * + * @param matrix instance of BigDenseMatrix + * @throws XGBoostError native error + */ + public DMatrix(BigDenseMatrix matrix) throws XGBoostError { + this(matrix, 0.0f); + } + /** * create DMatrix from dense matrix * @param data data values @@ -143,6 +154,18 @@ public class DMatrix { handle = out[0]; } + /** + * create DMatrix from dense matrix + * @param matrix instance of BigDenseMatrix + * @param missing the specified value to represent the missing value + */ + public DMatrix(BigDenseMatrix matrix, float missing) throws XGBoostError { + long[] out = new long[1]; + XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromMatRef(matrix.address, matrix.nrow, + matrix.ncol, missing, out)); + handle = out[0]; + } + /** * used for DMatrix slice */ diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java index 53965e86d..4eee147e9 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java @@ -65,6 +65,9 @@ class XGBoostJNI { public final static native int XGDMatrixCreateFromMat(float[] data, int nrow, int ncol, float missing, long[] out); + public final static native int XGDMatrixCreateFromMatRef(long dataRef, int nrow, int ncol, + float missing, long[] out); + public final static native int XGDMatrixSliceDMatrix(long handle, int[] idxset, long[] out); public final static native int XGDMatrixFree(long handle); diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/util/BigDenseMatrix.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/util/BigDenseMatrix.java new file mode 100644 index 000000000..9dbebb544 --- /dev/null +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/util/BigDenseMatrix.java @@ -0,0 +1,76 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package ml.dmlc.xgboost4j.java.util; + +/** + * Off-heap implementation of a Dense Matrix, matrix size is only limited by the + * amount of the available memory and the matrix dimension cannot exceed + * Integer.MAX_VALUE (this is consistent with XGBoost API restrictions on maximum + * length of a response). + */ +public final class BigDenseMatrix { + + private static final int FLOAT_BYTE_SIZE = 4; + public static final long MAX_MATRIX_SIZE = Long.MAX_VALUE / FLOAT_BYTE_SIZE; + + public final int nrow; + public final int ncol; + public final long address; + + public static void setDirect(long valAddress, float val) { + UtilUnsafe.UNSAFE.putFloat(valAddress, val); + } + + public static float getDirect(long valAddress) { + return UtilUnsafe.UNSAFE.getFloat(valAddress); + } + + public BigDenseMatrix(int nrow, int ncol) { + final long size = (long) nrow * ncol; + if (size > MAX_MATRIX_SIZE) { + throw new IllegalArgumentException("Matrix too large; matrix size cannot exceed " + + MAX_MATRIX_SIZE); + } + this.nrow = nrow; + this.ncol = ncol; + this.address = UtilUnsafe.UNSAFE.allocateMemory(size * FLOAT_BYTE_SIZE); + } + + public final void set(long idx, float val) { + setDirect(address + idx * FLOAT_BYTE_SIZE, val); + } + + public final void set(int i, int j, float val) { + set(index(i, j), val); + } + + public final float get(long idx) { + return getDirect(address + idx * FLOAT_BYTE_SIZE); + } + + public final float get(int i, int j) { + return get(index(i, j)); + } + + public final void dispose() { + UtilUnsafe.UNSAFE.freeMemory(address); + } + + private long index(int i, int j) { + return (long) i * ncol + j; + } + +} diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/util/UtilUnsafe.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/util/UtilUnsafe.java new file mode 100644 index 000000000..501a9cfe1 --- /dev/null +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/util/UtilUnsafe.java @@ -0,0 +1,46 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package ml.dmlc.xgboost4j.java.util; + +import java.lang.reflect.Field; + +import sun.misc.Unsafe; + +/** + * Simple class to obtain access to the {@link Unsafe} object. Use responsibly :) + */ +public final class UtilUnsafe { + + static Unsafe UNSAFE = getUnsafe(); + + private UtilUnsafe() { + } // dummy private constructor + + private static Unsafe getUnsafe() { + // Not on bootclasspath + if (UtilUnsafe.class.getClassLoader() == null) { + return Unsafe.getUnsafe(); + } + try { + final Field fld = Unsafe.class.getDeclaredField("theUnsafe"); + fld.setAccessible(true); + return (Unsafe) fld.get(UtilUnsafe.class); + } catch (Exception e) { + throw new RuntimeException("Could not obtain access to sun.misc.Unsafe", e); + } + } + +} diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp index cf6bfbf4d..b6b9a8377 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp @@ -247,6 +247,21 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro return ret; } +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: XGDMatrixCreateFromMatRef + * Signature: (JIIF)J + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromMatRef + (JNIEnv *jenv, jclass jcls, jlong jdataRef, jint jnrow, jint jncol, jfloat jmiss, jlongArray jout) { + DMatrixHandle result; + bst_ulong nrow = (bst_ulong)jnrow; + bst_ulong ncol = (bst_ulong)jncol; + jint ret = (jint) XGDMatrixCreateFromMat((float const *)jdataRef, nrow, ncol, jmiss, &result); + setHandle(jenv, jout, result); + return ret; +} + /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.h b/jvm-packages/xgboost4j/src/native/xgboost4j.h index eaefe3925..3d0f0c468 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.h +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.h @@ -55,6 +55,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromMat (JNIEnv *, jclass, jfloatArray, jint, jint, jfloat, jlongArray); +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: XGDMatrixCreateFromMatRef + * Signature: (JIIF[J)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromMatRef + (JNIEnv *jenv, jclass jcls, jlong jdataRef, jint jnrow, jint jncol, jfloat jmiss, jlongArray jout); + /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: XGDMatrixSliceDMatrix diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java index b121bb887..75f52d877 100644 --- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java @@ -15,13 +15,20 @@ */ package ml.dmlc.xgboost4j.java; +import java.io.*; import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; import java.util.Random; import junit.framework.TestCase; +import ml.dmlc.xgboost4j.java.util.BigDenseMatrix; import ml.dmlc.xgboost4j.LabeledPoint; import org.junit.Test; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + /** * test cases for DMatrix * @@ -53,7 +60,8 @@ public class DMatrixTest { @Test public void testCreateFromFile() throws XGBoostError { //create DMatrix from file - DMatrix dmat = new DMatrix("../../demo/data/agaricus.txt.test"); + String filePath = writeResourceIntoTempFile("/agaricus.txt.test"); + DMatrix dmat = new DMatrix(filePath); //get label float[] labels = dmat.getLabel(); //check length @@ -224,6 +232,122 @@ public class DMatrixTest { TestCase.assertTrue(dmat0.getLabel().length == 10); } + @Test + public void testCreateFromDenseMatrixRef() throws XGBoostError { + //create DMatrix from 10*5 dense matrix + final int nrow = 10; + final int ncol = 5; + + DMatrix dmat0 = null; + BigDenseMatrix data0 = null; + try { + data0 = new BigDenseMatrix(nrow, ncol); + //put random nums + Random random = new Random(); + for (int i = 0; i < nrow * ncol; i++) { + data0.set(i, random.nextFloat()); + } + + //create label + float[] label0 = new float[nrow]; + for (int i = 0; i < nrow; i++) { + label0[i] = random.nextFloat(); + } + + dmat0 = new DMatrix(data0); + dmat0.setLabel(label0); + + //check + TestCase.assertTrue(dmat0.rowNum() == 10); + TestCase.assertTrue(dmat0.getLabel().length == 10); + } finally { + if (dmat0 != null) { + dmat0.dispose(); + } else if (data0 != null){ + data0.dispose(); + } + } + } + + @Test + public void testTrainWithDenseMatrixRef() throws XGBoostError { + Map rabitEnv = new HashMap<>(); + rabitEnv.put("DMLC_TASK_ID", "0"); + Rabit.init(rabitEnv); + DMatrix trainMat = null; + BigDenseMatrix data0 = null; + try { + // trivial dataset with 3 rows and 2 columns + // (4,5) -> 1 + // (3,1) -> 2 + // (2,3) -> 3 + float[][] data = new float[][]{ + new float[]{4f, 5f}, + new float[]{3f, 1f}, + new float[]{2f, 3f} + }; + data0 = new BigDenseMatrix(3, 2); + for (int i = 0; i < data0.nrow; i++) + for (int j = 0; j < data0.ncol; j++) + data0.set(i, j, data[i][j]); + + trainMat = new DMatrix(data0); + trainMat.setLabel(new float[]{1f, 2f, 3f}); + + HashMap params = new HashMap<>(); + params.put("eta", 1); + params.put("max_depth", 5); + params.put("silent", 1); + params.put("objective", "reg:linear"); + params.put("seed", 123); + + HashMap watches = new HashMap<>(); + watches.put("train", trainMat); + + Booster booster = XGBoost.train(trainMat, params, 10, watches, null, null); + + // check overfitting + // (4,5) -> 1 + // (3,1) -> 2 + // (2,3) -> 3 + for (int i = 0; i < 3; i++) { + float[][] preds = booster.predict(new DMatrix(data[i], 1, 2)); + assertEquals(1, preds.length); + assertArrayEquals(new float[]{(float) (i + 1)}, preds[0], 1e-2f); + } + } finally { + if (trainMat != null) + trainMat.dispose(); + else if (data0 != null) { + data0.dispose(); + } + Rabit.shutdown(); + } + } + + private String writeResourceIntoTempFile(String resource) { + InputStream input = getClass().getResourceAsStream(resource); + if (input == null) { + throw new IllegalArgumentException("Resource " + resource + " does not exist."); + } + File tmp; + try { + tmp = File.createTempFile("junit", ".test"); + } catch (IOException e) { + throw new RuntimeException("Unable to write to temp file.", e); + } + byte[] buff = new byte[1024]; + try (FileOutputStream output = new FileOutputStream(tmp)) { + int n; + while ((n = input.read(buff)) > 0) { + output.write(buff, 0, n); + } + } catch (IOException e) { + throw new RuntimeException("Unable to write to temp file.", e); + } + return tmp.getAbsolutePath(); + } + @Test public void testSetAndGetGroup() throws XGBoostError { //create DMatrix from 10*5 dense matrix