From 37bc122c90b7c1337403f797fb7b01fc6abd5b36 Mon Sep 17 00:00:00 2001 From: Nan Zhu Date: Mon, 26 Sep 2016 13:35:04 -0400 Subject: [PATCH] [jvm-packages] Robust dmatrix creation (#1613) * add back train method but mark as deprecated * robust matrix creation in jvm --- .../java/ml/dmlc/xgboost4j/java/DMatrix.java | 30 +++++++- .../ml/dmlc/xgboost4j/java/XGBoostJNI.java | 8 +-- .../ml/dmlc/xgboost4j/scala/DMatrix.scala | 17 +++++ .../xgboost4j/src/native/xgboost4j.cpp | 17 ++--- jvm-packages/xgboost4j/src/native/xgboost4j.h | 12 ++-- .../ml/dmlc/xgboost4j/java/DMatrixTest.java | 72 +++++++++++++++++++ .../dmlc/xgboost4j/scala/DMatrixSuite.scala | 61 ++++++++++++++++ 7 files changed, 197 insertions(+), 20 deletions(-) diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java index 8d4db66a1..b2b55597c 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java @@ -92,12 +92,38 @@ public class DMatrix { * @param st Type of sparsity. * @throws XGBoostError */ + @Deprecated public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError { long[] out = new long[1]; if (st == SparseType.CSR) { - JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSR(headers, indices, data, out)); + JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSREx(headers, indices, data, 0, out)); } else if (st == SparseType.CSC) { - JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSC(headers, indices, data, out)); + JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSCEx(headers, indices, data, 0, out)); + } else { + throw new UnknownError("unknow sparsetype"); + } + handle = out[0]; + } + + /** + * Create DMatrix from Sparse matrix in CSR/CSC format. + * @param headers The row index of the matrix. 
+ * @param indices The indices of presenting entries. + * @param data The data content. + * @param st Type of sparsity. + * @param shapeParam when st is CSR, it specifies the column number, otherwise it is taken as + * row number + * @throws XGBoostError + */ + public DMatrix(long[] headers, int[] indices, float[] data, SparseType st, int shapeParam) + throws XGBoostError { + long[] out = new long[1]; + if (st == SparseType.CSR) { + JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSREx(headers, indices, data, + shapeParam, out)); + } else if (st == SparseType.CSC) { + JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSCEx(headers, indices, data, + shapeParam, out)); } else { throw new UnknownError("unknow sparsetype"); } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java index 29ea29163..4ecef65a7 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java @@ -30,11 +30,11 @@ class XGBoostJNI { final static native int XGDMatrixCreateFromDataIter(java.util.Iterator iter, String cache_info, long[] out); - public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data, - long[] out); + public final static native int XGDMatrixCreateFromCSREx(long[] indptr, int[] indices, float[] data, + int shapeParam, long[] out); - public final static native int XGDMatrixCreateFromCSC(long[] colptr, int[] indices, float[] data, - long[] out); + public final static native int XGDMatrixCreateFromCSCEx(long[] colptr, int[] indices, float[] data, + int shapeParam, long[] out); public final static native int XGDMatrixCreateFromMat(float[] data, int nrow, int ncol, float missing, long[] out); diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala 
b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala index bb21c3004..bf2952ec5 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala @@ -51,10 +51,27 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) { * @param st sparse matrix type (CSR or CSC) */ @throws(classOf[XGBoostError]) + @deprecated def this(headers: Array[Long], indices: Array[Int], data: Array[Float], st: JDMatrix.SparseType) { this(new JDMatrix(headers, indices, data, st)) } + /** + * create DMatrix from sparse matrix + * + * @param headers index to headers (rowHeaders for CSR or colHeaders for CSC) + * @param indices Indices (colIndexs for CSR or rowIndexs for CSC) + * @param data non zero values (sequence by row for CSR or by col for CSC) + * @param st sparse matrix type (CSR or CSC) + * @param shapeParam when st is CSR, it specifies the column number, otherwise it is taken as + * row number + */ + @throws(classOf[XGBoostError]) + def this(headers: Array[Long], indices: Array[Int], data: Array[Float], st: JDMatrix.SparseType, + shapeParam: Int) { + this(new JDMatrix(headers, indices, data, st, shapeParam)) + } + /** * create DMatrix from dense matrix * diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp index 9475bb248..9df04e54e 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp @@ -188,18 +188,18 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI - * Method: XGDMatrixCreateFromCSR - * Signature: ([J[J[F)J + * Method: XGDMatrixCreateFromCSREx + * Signature: ([J[I[FI[J)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSR - (JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) { 
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSREx + (JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jint jcol, jlongArray jout) { DMatrixHandle result; jlong* indptr = jenv->GetLongArrayElements(jindptr, 0); jint* indices = jenv->GetIntArrayElements(jindices, 0); jfloat* data = jenv->GetFloatArrayElements(jdata, 0); bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr); bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata); - int ret = (jint) XGDMatrixCreateFromCSR((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, &result); + int ret = (jint) XGDMatrixCreateFromCSREx((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, jcol, &result); setHandle(jenv, jout, result); //Release jenv->ReleaseLongArrayElements(jindptr, indptr, 0); @@ -210,11 +210,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI - * Method: XGDMatrixCreateFromCSC - * Signature: ([J[J[F)J + * Method: XGDMatrixCreateFromCSCEx + * Signature: ([J[I[FI[J)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSC - (JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) { +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSCEx + (JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jint jrow, jlongArray jout) { DMatrixHandle result; jlong* indptr = jenv->GetLongArrayElements(jindptr, NULL); jint* indices = jenv->GetIntArrayElements(jindices, 0); @@ -222,7 +222,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr); bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata); - int ret = (jint) XGDMatrixCreateFromCSC((unsigned long const *)indptr, 
(unsigned int const *)indices, (float const *)data, nindptr, nelem, &result); + int ret = (jint) XGDMatrixCreateFromCSCEx((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, jrow, &result); setHandle(jenv, jout, result); //release jenv->ReleaseLongArrayElements(jindptr, indptr, 0); @@ -232,6 +232,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro return ret; } + /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: XGDMatrixCreateFromMat diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.h b/jvm-packages/xgboost4j/src/native/xgboost4j.h index eb2e0244a..15410abed 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.h +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.h @@ -33,19 +33,19 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI - * Method: XGDMatrixCreateFromCSR - * Signature: ([J[I[F[J)I + * Method: XGDMatrixCreateFromCSREx + * Signature: ([J[I[FI[J)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSR - (JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jlongArray); +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSREx + (JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jint, jlongArray); /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI - * Method: XGDMatrixCreateFromCSC - * Signature: ([J[I[F[J)I + * Method: XGDMatrixCreateFromCSCEx + * Signature: ([J[I[FI[J)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSC - (JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jlongArray); +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSCEx + (JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jint, jlongArray); /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java 
b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java index 0889f5777..e2bf5c7ab 100644 --- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java @@ -91,6 +91,78 @@ public class DMatrixTest { TestCase.assertTrue(Arrays.equals(label1, label2)); } + @Test + public void testCreateFromCSREx() throws XGBoostError { + //create Matrix from csr format sparse Matrix and labels + /** + * sparse matrix + * 1 0 2 3 0 + * 4 0 2 3 5 + * 3 1 2 5 0 + */ + float[] data = new float[]{1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5}; + int[] colIndex = new int[]{0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3}; + long[] rowHeaders = new long[]{0, 3, 7, 11}; + DMatrix dmat1 = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR, 5); + //check row num + TestCase.assertTrue(dmat1.rowNum() == 3); + //test set label + float[] label1 = new float[]{1, 0, 1}; + dmat1.setLabel(label1); + float[] label2 = dmat1.getLabel(); + TestCase.assertTrue(Arrays.equals(label1, label2)); + } + + @Test + public void testCreateFromCSC() throws XGBoostError { + //create Matrix from csc format sparse Matrix and labels + /** + * sparse matrix + * 1 0 2 + * 3 0 4 + * 0 2 3 + * 5 3 1 + * 2 5 0 + */ + float[] data = new float[]{1, 3, 5, 2, 2, 3, 5, 2, 4, 3, 1}; + int[] rowIndex = new int[]{0, 1, 3, 4, 2, 3, 4, 0, 1, 2, 3}; + long[] colHeaders = new long[]{0, 4, 7, 11}; + DMatrix dmat1 = new DMatrix(colHeaders, rowIndex, data, DMatrix.SparseType.CSC); + //check row num + System.out.println(dmat1.rowNum()); + TestCase.assertTrue(dmat1.rowNum() == 5); + //test set label + float[] label1 = new float[]{1, 0, 1, 1, 1}; + dmat1.setLabel(label1); + float[] label2 = dmat1.getLabel(); + TestCase.assertTrue(Arrays.equals(label1, label2)); + } + + @Test + public void testCreateFromCSCEx() throws XGBoostError { + //create Matrix from csc format sparse Matrix and labels + /** + * sparse matrix + * 1 0 2 + * 3 0 4 + * 
0 2 3 + * 5 3 1 + * 2 5 0 + */ + float[] data = new float[]{1, 3, 5, 2, 2, 3, 5, 2, 4, 3, 1}; + int[] rowIndex = new int[]{0, 1, 3, 4, 2, 3, 4, 0, 1, 2, 3}; + long[] colHeaders = new long[]{0, 4, 7, 11}; + DMatrix dmat1 = new DMatrix(colHeaders, rowIndex, data, DMatrix.SparseType.CSC, 5); + //check row num + System.out.println(dmat1.rowNum()); + TestCase.assertTrue(dmat1.rowNum() == 5); + //test set label + float[] label1 = new float[]{1, 0, 1, 1, 1}; + dmat1.setLabel(label1); + float[] label2 = dmat1.getLabel(); + TestCase.assertTrue(Arrays.equals(label1, label2)); + } + @Test public void testCreateFromDenseMatrix() throws XGBoostError { //create DMatrix from 10*5 dense matrix diff --git a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/DMatrixSuite.scala b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/DMatrixSuite.scala index 0080b1dd4..87ff8e006 100644 --- a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/DMatrixSuite.scala +++ b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/DMatrixSuite.scala @@ -56,6 +56,67 @@ class DMatrixSuite extends FunSuite { assert(label2 === label1) } + test("create DMatrix from CSREx") { + // create Matrix from csr format sparse Matrix and labels + /** + * sparse matrix + * 1 0 2 3 0 + * 4 0 2 3 5 + * 3 1 2 5 0 + */ + val data = List[Float](1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5).toArray + val colIndex = List(0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3).toArray + val rowHeaders = List[Long](0, 3, 7, 11).toArray + val dmat1 = new DMatrix(rowHeaders, colIndex, data, JDMatrix.SparseType.CSR, 5) + assert(dmat1.rowNum === 3) + val label1 = List[Float](1, 0, 1).toArray + dmat1.setLabel(label1) + val label2 = dmat1.getLabel + assert(label2 === label1) + } + + test("create DMatrix from CSC") { + // create Matrix from csc format sparse Matrix and labels + /** + * sparse matrix + * 1 0 2 + * 3 0 4 + * 0 2 3 + * 5 3 1 + * 2 5 0 + */ + val data = List[Float](1, 3, 5, 2, 2, 3, 5, 2, 4, 3, 1).toArray + val 
rowIndex = List(0, 1, 3, 4, 2, 3, 4, 0, 1, 2, 3).toArray + val colHeaders = List[Long](0, 4, 7, 11).toArray + val dmat1 = new DMatrix(colHeaders, rowIndex, data, JDMatrix.SparseType.CSC) + assert(dmat1.rowNum === 5) + val label1 = List[Float](1, 0, 1, 1, 1).toArray + dmat1.setLabel(label1) + val label2 = dmat1.getLabel + assert(label2 === label1) + } + + test("create DMatrix from CSCEx") { + // create Matrix from csc format sparse Matrix and labels + /** + * sparse matrix + * 1 0 2 + * 3 0 4 + * 0 2 3 + * 5 3 1 + * 2 5 0 + */ + val data = List[Float](1, 3, 5, 2, 2, 3, 5, 2, 4, 3, 1).toArray + val rowIndex = List(0, 1, 3, 4, 2, 3, 4, 0, 1, 2, 3).toArray + val colHeaders = List[Long](0, 4, 7, 11).toArray + val dmat1 = new DMatrix(colHeaders, rowIndex, data, JDMatrix.SparseType.CSC, 5) + assert(dmat1.rowNum === 5) + val label1 = List[Float](1, 0, 1, 1, 1).toArray + dmat1.setLabel(label1) + val label2 = dmat1.getLabel + assert(label2 === label1) + } + test("create DMatrix from DenseMatrix") { val nrow = 10 val ncol = 5