[jvm-packages] Robust dmatrix creation (#1613)
* add back train method but mark as deprecated * robust matrix creation in jvm
This commit is contained in:
parent
915ac0b8fe
commit
37bc122c90
@ -92,12 +92,38 @@ public class DMatrix {
|
||||
* @param st Type of sparsity.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
@Deprecated
|
||||
public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError {
|
||||
long[] out = new long[1];
|
||||
if (st == SparseType.CSR) {
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSR(headers, indices, data, out));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSREx(headers, indices, data, 0, out));
|
||||
} else if (st == SparseType.CSC) {
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSC(headers, indices, data, out));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSCEx(headers, indices, data, 0, out));
|
||||
} else {
|
||||
throw new UnknownError("unknow sparsetype");
|
||||
}
|
||||
handle = out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Create DMatrix from Sparse matrix in CSR/CSC format.
|
||||
* @param headers The row index of the matrix.
|
||||
* @param indices The indices of presenting entries.
|
||||
* @param data The data content.
|
||||
* @param st Type of sparsity.
|
||||
* @param shapeParam when st is CSR, it specifies the column number, otherwise it is taken as
|
||||
* row number
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public DMatrix(long[] headers, int[] indices, float[] data, SparseType st, int shapeParam)
|
||||
throws XGBoostError {
|
||||
long[] out = new long[1];
|
||||
if (st == SparseType.CSR) {
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSREx(headers, indices, data,
|
||||
shapeParam, out));
|
||||
} else if (st == SparseType.CSC) {
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSCEx(headers, indices, data,
|
||||
shapeParam, out));
|
||||
} else {
|
||||
throw new UnknownError("unknow sparsetype");
|
||||
}
|
||||
|
||||
@ -30,11 +30,11 @@ class XGBoostJNI {
|
||||
final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter,
|
||||
String cache_info, long[] out);
|
||||
|
||||
public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data,
|
||||
long[] out);
|
||||
public final static native int XGDMatrixCreateFromCSREx(long[] indptr, int[] indices, float[] data,
|
||||
int shapeParam, long[] out);
|
||||
|
||||
public final static native int XGDMatrixCreateFromCSC(long[] colptr, int[] indices, float[] data,
|
||||
long[] out);
|
||||
public final static native int XGDMatrixCreateFromCSCEx(long[] colptr, int[] indices, float[] data,
|
||||
int shapeParam, long[] out);
|
||||
|
||||
public final static native int XGDMatrixCreateFromMat(float[] data, int nrow, int ncol,
|
||||
float missing, long[] out);
|
||||
|
||||
@ -51,10 +51,27 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
||||
* @param st sparse matrix type (CSR or CSC)
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
@deprecated
|
||||
def this(headers: Array[Long], indices: Array[Int], data: Array[Float], st: JDMatrix.SparseType) {
|
||||
this(new JDMatrix(headers, indices, data, st))
|
||||
}
|
||||
|
||||
/**
|
||||
* create DMatrix from sparse matrix
|
||||
*
|
||||
* @param headers index to headers (rowHeaders for CSR or colHeaders for CSC)
|
||||
* @param indices Indices (colIndexs for CSR or rowIndexs for CSC)
|
||||
* @param data non zero values (sequence by row for CSR or by col for CSC)
|
||||
* @param st sparse matrix type (CSR or CSC)
|
||||
* @param shapeParam when st is CSR, it specifies the column number, otherwise it is taken as
|
||||
* row number
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def this(headers: Array[Long], indices: Array[Int], data: Array[Float], st: JDMatrix.SparseType,
|
||||
shapeParam: Int) {
|
||||
this(new JDMatrix(headers, indices, data, st, shapeParam))
|
||||
}
|
||||
|
||||
/**
|
||||
* create DMatrix from dense matrix
|
||||
*
|
||||
|
||||
@ -188,18 +188,18 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixCreateFromCSR
|
||||
* Method: XGDMatrixCreateFromCSREx
|
||||
* Signature: ([J[J[F)J
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSR
|
||||
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) {
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSREx
|
||||
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jint jcol, jlongArray jout) {
|
||||
DMatrixHandle result;
|
||||
jlong* indptr = jenv->GetLongArrayElements(jindptr, 0);
|
||||
jint* indices = jenv->GetIntArrayElements(jindices, 0);
|
||||
jfloat* data = jenv->GetFloatArrayElements(jdata, 0);
|
||||
bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr);
|
||||
bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata);
|
||||
int ret = (jint) XGDMatrixCreateFromCSR((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, &result);
|
||||
int ret = (jint) XGDMatrixCreateFromCSREx((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, jcol, &result);
|
||||
setHandle(jenv, jout, result);
|
||||
//Release
|
||||
jenv->ReleaseLongArrayElements(jindptr, indptr, 0);
|
||||
@ -210,11 +210,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixCreateFromCSC
|
||||
* Method: XGDMatrixCreateFromCSCEx
|
||||
* Signature: ([J[J[F)J
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSC
|
||||
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) {
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSCEx
|
||||
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jint jrow, jlongArray jout) {
|
||||
DMatrixHandle result;
|
||||
jlong* indptr = jenv->GetLongArrayElements(jindptr, NULL);
|
||||
jint* indices = jenv->GetIntArrayElements(jindices, 0);
|
||||
@ -222,7 +222,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
|
||||
bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr);
|
||||
bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata);
|
||||
|
||||
int ret = (jint) XGDMatrixCreateFromCSC((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, &result);
|
||||
int ret = (jint) XGDMatrixCreateFromCSCEx((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, jrow, &result);
|
||||
setHandle(jenv, jout, result);
|
||||
//release
|
||||
jenv->ReleaseLongArrayElements(jindptr, indptr, 0);
|
||||
@ -232,6 +232,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixCreateFromMat
|
||||
|
||||
@ -33,19 +33,19 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixCreateFromCSR
|
||||
* Method: XGDMatrixCreateFromCSREx
|
||||
* Signature: ([J[I[F[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSR
|
||||
(JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jlongArray);
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSREx
|
||||
(JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jint, jlongArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixCreateFromCSC
|
||||
* Method: XGDMatrixCreateFromCSCEx
|
||||
* Signature: ([J[I[F[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSC
|
||||
(JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jlongArray);
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSCEx
|
||||
(JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jint, jlongArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
|
||||
@ -91,6 +91,78 @@ public class DMatrixTest {
|
||||
TestCase.assertTrue(Arrays.equals(label1, label2));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCreateFromCSREx() throws XGBoostError {
|
||||
//create Matrix from csr format sparse Matrix and labels
|
||||
/**
|
||||
* sparse matrix
|
||||
* 1 0 2 3 0
|
||||
* 4 0 2 3 5
|
||||
* 3 1 2 5 0
|
||||
*/
|
||||
float[] data = new float[]{1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5};
|
||||
int[] colIndex = new int[]{0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3};
|
||||
long[] rowHeaders = new long[]{0, 3, 7, 11};
|
||||
DMatrix dmat1 = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR, 5);
|
||||
//check row num
|
||||
TestCase.assertTrue(dmat1.rowNum() == 3);
|
||||
//test set label
|
||||
float[] label1 = new float[]{1, 0, 1};
|
||||
dmat1.setLabel(label1);
|
||||
float[] label2 = dmat1.getLabel();
|
||||
TestCase.assertTrue(Arrays.equals(label1, label2));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCreateFromCSC() throws XGBoostError {
|
||||
//create Matrix from csc format sparse Matrix and labels
|
||||
/**
|
||||
* sparse matrix
|
||||
* 1 0 2
|
||||
* 3 0 4
|
||||
* 0 2 3
|
||||
* 5 3 1
|
||||
* 2 5 0
|
||||
*/
|
||||
float[] data = new float[]{1, 3, 5, 2, 2, 3, 5, 2, 4, 3, 1};
|
||||
int[] rowIndex = new int[]{0, 1, 3, 4, 2, 3, 4, 0, 1, 2, 3};
|
||||
long[] colHeaders = new long[]{0, 4, 7, 11};
|
||||
DMatrix dmat1 = new DMatrix(colHeaders, rowIndex, data, DMatrix.SparseType.CSC);
|
||||
//check row num
|
||||
System.out.println(dmat1.rowNum());
|
||||
TestCase.assertTrue(dmat1.rowNum() == 5);
|
||||
//test set label
|
||||
float[] label1 = new float[]{1, 0, 1, 1, 1};
|
||||
dmat1.setLabel(label1);
|
||||
float[] label2 = dmat1.getLabel();
|
||||
TestCase.assertTrue(Arrays.equals(label1, label2));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCreateFromCSCEx() throws XGBoostError {
|
||||
//create Matrix from csc format sparse Matrix and labels
|
||||
/**
|
||||
* sparse matrix
|
||||
* 1 0 2
|
||||
* 3 0 4
|
||||
* 0 2 3
|
||||
* 5 3 1
|
||||
* 2 5 0
|
||||
*/
|
||||
float[] data = new float[]{1, 3, 5, 2, 2, 3, 5, 2, 4, 3, 1};
|
||||
int[] rowIndex = new int[]{0, 1, 3, 4, 2, 3, 4, 0, 1, 2, 3};
|
||||
long[] colHeaders = new long[]{0, 4, 7, 11};
|
||||
DMatrix dmat1 = new DMatrix(colHeaders, rowIndex, data, DMatrix.SparseType.CSC, 5);
|
||||
//check row num
|
||||
System.out.println(dmat1.rowNum());
|
||||
TestCase.assertTrue(dmat1.rowNum() == 5);
|
||||
//test set label
|
||||
float[] label1 = new float[]{1, 0, 1, 1, 1};
|
||||
dmat1.setLabel(label1);
|
||||
float[] label2 = dmat1.getLabel();
|
||||
TestCase.assertTrue(Arrays.equals(label1, label2));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCreateFromDenseMatrix() throws XGBoostError {
|
||||
//create DMatrix from 10*5 dense matrix
|
||||
|
||||
@ -56,6 +56,67 @@ class DMatrixSuite extends FunSuite {
|
||||
assert(label2 === label1)
|
||||
}
|
||||
|
||||
test("create DMatrix from CSREx") {
|
||||
// create Matrix from csr format sparse Matrix and labels
|
||||
/**
|
||||
* sparse matrix
|
||||
* 1 0 2 3 0
|
||||
* 4 0 2 3 5
|
||||
* 3 1 2 5 0
|
||||
*/
|
||||
val data = List[Float](1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5).toArray
|
||||
val colIndex = List(0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3).toArray
|
||||
val rowHeaders = List[Long](0, 3, 7, 11).toArray
|
||||
val dmat1 = new DMatrix(rowHeaders, colIndex, data, JDMatrix.SparseType.CSR, 5)
|
||||
assert(dmat1.rowNum === 3)
|
||||
val label1 = List[Float](1, 0, 1).toArray
|
||||
dmat1.setLabel(label1)
|
||||
val label2 = dmat1.getLabel
|
||||
assert(label2 === label1)
|
||||
}
|
||||
|
||||
test("create DMatrix from CSC") {
|
||||
// create Matrix from csc format sparse Matrix and labels
|
||||
/**
|
||||
* sparse matrix
|
||||
* 1 0 2
|
||||
* 3 0 4
|
||||
* 0 2 3
|
||||
* 5 3 1
|
||||
* 2 5 0
|
||||
*/
|
||||
val data = List[Float](1, 3, 5, 2, 2, 3, 5, 2, 4, 3, 1).toArray
|
||||
val rowIndex = List(0, 1, 3, 4, 2, 3, 4, 0, 1, 2, 3).toArray
|
||||
val colHeaders = List[Long](0, 4, 7, 11).toArray
|
||||
val dmat1 = new DMatrix(colHeaders, rowIndex, data, JDMatrix.SparseType.CSC)
|
||||
assert(dmat1.rowNum === 5)
|
||||
val label1 = List[Float](1, 0, 1, 1, 1).toArray
|
||||
dmat1.setLabel(label1)
|
||||
val label2 = dmat1.getLabel
|
||||
assert(label2 === label1)
|
||||
}
|
||||
|
||||
test("create DMatrix from CSCEx") {
|
||||
// create Matrix from csc format sparse Matrix and labels
|
||||
/**
|
||||
* sparse matrix
|
||||
* 1 0 2
|
||||
* 3 0 4
|
||||
* 0 2 3
|
||||
* 5 3 1
|
||||
* 2 5 0
|
||||
*/
|
||||
val data = List[Float](1, 3, 5, 2, 2, 3, 5, 2, 4, 3, 1).toArray
|
||||
val rowIndex = List(0, 1, 3, 4, 2, 3, 4, 0, 1, 2, 3).toArray
|
||||
val colHeaders = List[Long](0, 4, 7, 11).toArray
|
||||
val dmat1 = new DMatrix(colHeaders, rowIndex, data, JDMatrix.SparseType.CSC, 5)
|
||||
assert(dmat1.rowNum === 5)
|
||||
val label1 = List[Float](1, 0, 1, 1, 1).toArray
|
||||
dmat1.setLabel(label1)
|
||||
val label2 = dmat1.getLabel
|
||||
assert(label2 === label1)
|
||||
}
|
||||
|
||||
test("create DMatrix from DenseMatrix") {
|
||||
val nrow = 10
|
||||
val ncol = 5
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user