[jvm-packages] Robust dmatrix creation (#1613)

* add back train method but mark as deprecated

* robust matrix creation in jvm
This commit is contained in:
Nan Zhu 2016-09-26 13:35:04 -04:00 committed by GitHub
parent 915ac0b8fe
commit 37bc122c90
7 changed files with 197 additions and 20 deletions

View File

@ -92,12 +92,38 @@ public class DMatrix {
* @param st Type of sparsity.
* @throws XGBoostError
*/
@Deprecated
public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError {
long[] out = new long[1];
if (st == SparseType.CSR) {
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSR(headers, indices, data, out));
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSREx(headers, indices, data, 0, out));
} else if (st == SparseType.CSC) {
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSC(headers, indices, data, out));
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSCEx(headers, indices, data, 0, out));
} else {
throw new UnknownError("unknow sparsetype");
}
handle = out[0];
}
/**
* Create DMatrix from Sparse matrix in CSR/CSC format.
* @param headers The row index of the matrix.
* @param indices The indices of presenting entries.
* @param data The data content.
* @param st Type of sparsity.
* @param shapeParam when st is CSR, it specifies the column number, otherwise it is taken as
* row number
* @throws XGBoostError
*/
public DMatrix(long[] headers, int[] indices, float[] data, SparseType st, int shapeParam)
throws XGBoostError {
long[] out = new long[1];
if (st == SparseType.CSR) {
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSREx(headers, indices, data,
shapeParam, out));
} else if (st == SparseType.CSC) {
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSCEx(headers, indices, data,
shapeParam, out));
} else {
throw new UnknownError("unknow sparsetype");
}

View File

@ -30,11 +30,11 @@ class XGBoostJNI {
final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter,
String cache_info, long[] out);
public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data,
long[] out);
public final static native int XGDMatrixCreateFromCSREx(long[] indptr, int[] indices, float[] data,
int shapeParam, long[] out);
public final static native int XGDMatrixCreateFromCSC(long[] colptr, int[] indices, float[] data,
long[] out);
public final static native int XGDMatrixCreateFromCSCEx(long[] colptr, int[] indices, float[] data,
int shapeParam, long[] out);
public final static native int XGDMatrixCreateFromMat(float[] data, int nrow, int ncol,
float missing, long[] out);

View File

@ -51,10 +51,27 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
* @param st sparse matrix type (CSR or CSC)
*/
@throws(classOf[XGBoostError])
@deprecated
def this(headers: Array[Long], indices: Array[Int], data: Array[Float], st: JDMatrix.SparseType) {
this(new JDMatrix(headers, indices, data, st))
}
/**
* create DMatrix from sparse matrix
*
* @param headers index to headers (rowHeaders for CSR or colHeaders for CSC)
* @param indices Indices (colIndexs for CSR or rowIndexs for CSC)
* @param data non zero values (sequence by row for CSR or by col for CSC)
* @param st sparse matrix type (CSR or CSC)
* @param shapeParam when st is CSR, it specifies the column number, otherwise it is taken as
* row number
*/
@throws(classOf[XGBoostError])
def this(headers: Array[Long], indices: Array[Int], data: Array[Float], st: JDMatrix.SparseType,
shapeParam: Int) {
this(new JDMatrix(headers, indices, data, st, shapeParam))
}
/**
* create DMatrix from dense matrix
*

View File

@ -188,18 +188,18 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixCreateFromCSR
* Method: XGDMatrixCreateFromCSREx
* Signature: ([J[J[F)J
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSR
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) {
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSREx
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jint jcol, jlongArray jout) {
DMatrixHandle result;
jlong* indptr = jenv->GetLongArrayElements(jindptr, 0);
jint* indices = jenv->GetIntArrayElements(jindices, 0);
jfloat* data = jenv->GetFloatArrayElements(jdata, 0);
bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr);
bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata);
int ret = (jint) XGDMatrixCreateFromCSR((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, &result);
int ret = (jint) XGDMatrixCreateFromCSREx((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, jcol, &result);
setHandle(jenv, jout, result);
//Release
jenv->ReleaseLongArrayElements(jindptr, indptr, 0);
@ -210,11 +210,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixCreateFromCSC
* Method: XGDMatrixCreateFromCSCEx
* Signature: ([J[J[F)J
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSC
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) {
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSCEx
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jint jrow, jlongArray jout) {
DMatrixHandle result;
jlong* indptr = jenv->GetLongArrayElements(jindptr, NULL);
jint* indices = jenv->GetIntArrayElements(jindices, 0);
@ -222,7 +222,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr);
bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata);
int ret = (jint) XGDMatrixCreateFromCSC((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, &result);
int ret = (jint) XGDMatrixCreateFromCSCEx((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, jrow, &result);
setHandle(jenv, jout, result);
//release
jenv->ReleaseLongArrayElements(jindptr, indptr, 0);
@ -232,6 +232,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
return ret;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixCreateFromMat

View File

@ -33,19 +33,19 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixCreateFromCSR
* Method: XGDMatrixCreateFromCSREx
* Signature: ([J[I[F[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSR
(JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jlongArray);
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSREx
(JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jint, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixCreateFromCSC
* Method: XGDMatrixCreateFromCSCEx
* Signature: ([J[I[F[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSC
(JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jlongArray);
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSCEx
(JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jint, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI

View File

@ -91,6 +91,78 @@ public class DMatrixTest {
TestCase.assertTrue(Arrays.equals(label1, label2));
}
@Test
public void testCreateFromCSREx() throws XGBoostError {
//create Matrix from csr format sparse Matrix and labels
/**
* sparse matrix
* 1 0 2 3 0
* 4 0 2 3 5
* 3 1 2 5 0
*/
float[] data = new float[]{1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5};
int[] colIndex = new int[]{0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3};
long[] rowHeaders = new long[]{0, 3, 7, 11};
DMatrix dmat1 = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR, 5);
//check row num
TestCase.assertTrue(dmat1.rowNum() == 3);
//test set label
float[] label1 = new float[]{1, 0, 1};
dmat1.setLabel(label1);
float[] label2 = dmat1.getLabel();
TestCase.assertTrue(Arrays.equals(label1, label2));
}
@Test
public void testCreateFromCSC() throws XGBoostError {
//create Matrix from csc format sparse Matrix and labels
/**
* sparse matrix
* 1 0 2
* 3 0 4
* 0 2 3
* 5 3 1
* 2 5 0
*/
float[] data = new float[]{1, 3, 5, 2, 2, 3, 5, 2, 4, 3, 1};
int[] rowIndex = new int[]{0, 1, 3, 4, 2, 3, 4, 0, 1, 2, 3};
long[] colHeaders = new long[]{0, 4, 7, 11};
DMatrix dmat1 = new DMatrix(colHeaders, rowIndex, data, DMatrix.SparseType.CSC);
//check row num
System.out.println(dmat1.rowNum());
TestCase.assertTrue(dmat1.rowNum() == 5);
//test set label
float[] label1 = new float[]{1, 0, 1, 1, 1};
dmat1.setLabel(label1);
float[] label2 = dmat1.getLabel();
TestCase.assertTrue(Arrays.equals(label1, label2));
}
@Test
public void testCreateFromCSCEx() throws XGBoostError {
//create Matrix from csc format sparse Matrix and labels
/**
* sparse matrix
* 1 0 2
* 3 0 4
* 0 2 3
* 5 3 1
* 2 5 0
*/
float[] data = new float[]{1, 3, 5, 2, 2, 3, 5, 2, 4, 3, 1};
int[] rowIndex = new int[]{0, 1, 3, 4, 2, 3, 4, 0, 1, 2, 3};
long[] colHeaders = new long[]{0, 4, 7, 11};
DMatrix dmat1 = new DMatrix(colHeaders, rowIndex, data, DMatrix.SparseType.CSC, 5);
//check row num
System.out.println(dmat1.rowNum());
TestCase.assertTrue(dmat1.rowNum() == 5);
//test set label
float[] label1 = new float[]{1, 0, 1, 1, 1};
dmat1.setLabel(label1);
float[] label2 = dmat1.getLabel();
TestCase.assertTrue(Arrays.equals(label1, label2));
}
@Test
public void testCreateFromDenseMatrix() throws XGBoostError {
//create DMatrix from 10*5 dense matrix

View File

@ -56,6 +56,67 @@ class DMatrixSuite extends FunSuite {
assert(label2 === label1)
}
test("create DMatrix from CSREx") {
// create Matrix from csr format sparse Matrix and labels
/**
* sparse matrix
* 1 0 2 3 0
* 4 0 2 3 5
* 3 1 2 5 0
*/
val data = List[Float](1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5).toArray
val colIndex = List(0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3).toArray
val rowHeaders = List[Long](0, 3, 7, 11).toArray
val dmat1 = new DMatrix(rowHeaders, colIndex, data, JDMatrix.SparseType.CSR, 5)
assert(dmat1.rowNum === 3)
val label1 = List[Float](1, 0, 1).toArray
dmat1.setLabel(label1)
val label2 = dmat1.getLabel
assert(label2 === label1)
}
test("create DMatrix from CSC") {
// create Matrix from csc format sparse Matrix and labels
/**
* sparse matrix
* 1 0 2
* 3 0 4
* 0 2 3
* 5 3 1
* 2 5 0
*/
val data = List[Float](1, 3, 5, 2, 2, 3, 5, 2, 4, 3, 1).toArray
val rowIndex = List(0, 1, 3, 4, 2, 3, 4, 0, 1, 2, 3).toArray
val colHeaders = List[Long](0, 4, 7, 11).toArray
val dmat1 = new DMatrix(colHeaders, rowIndex, data, JDMatrix.SparseType.CSC)
assert(dmat1.rowNum === 5)
val label1 = List[Float](1, 0, 1, 1, 1).toArray
dmat1.setLabel(label1)
val label2 = dmat1.getLabel
assert(label2 === label1)
}
test("create DMatrix from CSCEx") {
// create Matrix from csc format sparse Matrix and labels
/**
* sparse matrix
* 1 0 2
* 3 0 4
* 0 2 3
* 5 3 1
* 2 5 0
*/
val data = List[Float](1, 3, 5, 2, 2, 3, 5, 2, 4, 3, 1).toArray
val rowIndex = List(0, 1, 3, 4, 2, 3, 4, 0, 1, 2, 3).toArray
val colHeaders = List[Long](0, 4, 7, 11).toArray
val dmat1 = new DMatrix(colHeaders, rowIndex, data, JDMatrix.SparseType.CSC, 5)
assert(dmat1.rowNum === 5)
val label1 = List[Float](1, 0, 1, 1, 1).toArray
dmat1.setLabel(label1)
val label2 = dmat1.getLabel
assert(label2 === label1)
}
test("create DMatrix from DenseMatrix") {
val nrow = 10
val ncol = 5