[jvm-packages] Robust dmatrix creation (#1613)
* add back train method but mark as deprecated * robust matrix creation in jvm
This commit is contained in:
parent
915ac0b8fe
commit
37bc122c90
@ -92,12 +92,38 @@ public class DMatrix {
|
|||||||
* @param st Type of sparsity.
|
* @param st Type of sparsity.
|
||||||
* @throws XGBoostError
|
* @throws XGBoostError
|
||||||
*/
|
*/
|
||||||
|
@Deprecated
|
||||||
public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError {
|
public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError {
|
||||||
long[] out = new long[1];
|
long[] out = new long[1];
|
||||||
if (st == SparseType.CSR) {
|
if (st == SparseType.CSR) {
|
||||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSR(headers, indices, data, out));
|
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSREx(headers, indices, data, 0, out));
|
||||||
} else if (st == SparseType.CSC) {
|
} else if (st == SparseType.CSC) {
|
||||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSC(headers, indices, data, out));
|
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSCEx(headers, indices, data, 0, out));
|
||||||
|
} else {
|
||||||
|
throw new UnknownError("unknow sparsetype");
|
||||||
|
}
|
||||||
|
handle = out[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create DMatrix from Sparse matrix in CSR/CSC format.
|
||||||
|
* @param headers The row index of the matrix.
|
||||||
|
* @param indices The indices of presenting entries.
|
||||||
|
* @param data The data content.
|
||||||
|
* @param st Type of sparsity.
|
||||||
|
* @param shapeParam when st is CSR, it specifies the column number, otherwise it is taken as
|
||||||
|
* row number
|
||||||
|
* @throws XGBoostError
|
||||||
|
*/
|
||||||
|
public DMatrix(long[] headers, int[] indices, float[] data, SparseType st, int shapeParam)
|
||||||
|
throws XGBoostError {
|
||||||
|
long[] out = new long[1];
|
||||||
|
if (st == SparseType.CSR) {
|
||||||
|
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSREx(headers, indices, data,
|
||||||
|
shapeParam, out));
|
||||||
|
} else if (st == SparseType.CSC) {
|
||||||
|
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSCEx(headers, indices, data,
|
||||||
|
shapeParam, out));
|
||||||
} else {
|
} else {
|
||||||
throw new UnknownError("unknow sparsetype");
|
throw new UnknownError("unknow sparsetype");
|
||||||
}
|
}
|
||||||
|
|||||||
@ -30,11 +30,11 @@ class XGBoostJNI {
|
|||||||
final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter,
|
final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter,
|
||||||
String cache_info, long[] out);
|
String cache_info, long[] out);
|
||||||
|
|
||||||
public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data,
|
public final static native int XGDMatrixCreateFromCSREx(long[] indptr, int[] indices, float[] data,
|
||||||
long[] out);
|
int shapeParam, long[] out);
|
||||||
|
|
||||||
public final static native int XGDMatrixCreateFromCSC(long[] colptr, int[] indices, float[] data,
|
public final static native int XGDMatrixCreateFromCSCEx(long[] colptr, int[] indices, float[] data,
|
||||||
long[] out);
|
int shapeParam, long[] out);
|
||||||
|
|
||||||
public final static native int XGDMatrixCreateFromMat(float[] data, int nrow, int ncol,
|
public final static native int XGDMatrixCreateFromMat(float[] data, int nrow, int ncol,
|
||||||
float missing, long[] out);
|
float missing, long[] out);
|
||||||
|
|||||||
@ -51,10 +51,27 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
|||||||
* @param st sparse matrix type (CSR or CSC)
|
* @param st sparse matrix type (CSR or CSC)
|
||||||
*/
|
*/
|
||||||
@throws(classOf[XGBoostError])
|
@throws(classOf[XGBoostError])
|
||||||
|
@deprecated
|
||||||
def this(headers: Array[Long], indices: Array[Int], data: Array[Float], st: JDMatrix.SparseType) {
|
def this(headers: Array[Long], indices: Array[Int], data: Array[Float], st: JDMatrix.SparseType) {
|
||||||
this(new JDMatrix(headers, indices, data, st))
|
this(new JDMatrix(headers, indices, data, st))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* create DMatrix from sparse matrix
|
||||||
|
*
|
||||||
|
* @param headers index to headers (rowHeaders for CSR or colHeaders for CSC)
|
||||||
|
* @param indices Indices (colIndexs for CSR or rowIndexs for CSC)
|
||||||
|
* @param data non zero values (sequence by row for CSR or by col for CSC)
|
||||||
|
* @param st sparse matrix type (CSR or CSC)
|
||||||
|
* @param shapeParam when st is CSR, it specifies the column number, otherwise it is taken as
|
||||||
|
* row number
|
||||||
|
*/
|
||||||
|
@throws(classOf[XGBoostError])
|
||||||
|
def this(headers: Array[Long], indices: Array[Int], data: Array[Float], st: JDMatrix.SparseType,
|
||||||
|
shapeParam: Int) {
|
||||||
|
this(new JDMatrix(headers, indices, data, st, shapeParam))
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* create DMatrix from dense matrix
|
* create DMatrix from dense matrix
|
||||||
*
|
*
|
||||||
|
|||||||
@ -188,18 +188,18 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
|
|||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
* Method: XGDMatrixCreateFromCSR
|
* Method: XGDMatrixCreateFromCSREx
|
||||||
* Signature: ([J[J[F)J
|
* Signature: ([J[J[F)J
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSR
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSREx
|
||||||
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) {
|
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jint jcol, jlongArray jout) {
|
||||||
DMatrixHandle result;
|
DMatrixHandle result;
|
||||||
jlong* indptr = jenv->GetLongArrayElements(jindptr, 0);
|
jlong* indptr = jenv->GetLongArrayElements(jindptr, 0);
|
||||||
jint* indices = jenv->GetIntArrayElements(jindices, 0);
|
jint* indices = jenv->GetIntArrayElements(jindices, 0);
|
||||||
jfloat* data = jenv->GetFloatArrayElements(jdata, 0);
|
jfloat* data = jenv->GetFloatArrayElements(jdata, 0);
|
||||||
bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr);
|
bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr);
|
||||||
bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata);
|
bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata);
|
||||||
int ret = (jint) XGDMatrixCreateFromCSR((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, &result);
|
int ret = (jint) XGDMatrixCreateFromCSREx((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, jcol, &result);
|
||||||
setHandle(jenv, jout, result);
|
setHandle(jenv, jout, result);
|
||||||
//Release
|
//Release
|
||||||
jenv->ReleaseLongArrayElements(jindptr, indptr, 0);
|
jenv->ReleaseLongArrayElements(jindptr, indptr, 0);
|
||||||
@ -210,11 +210,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
|
|||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
* Method: XGDMatrixCreateFromCSC
|
* Method: XGDMatrixCreateFromCSCEx
|
||||||
* Signature: ([J[J[F)J
|
* Signature: ([J[J[F)J
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSC
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSCEx
|
||||||
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) {
|
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jint jrow, jlongArray jout) {
|
||||||
DMatrixHandle result;
|
DMatrixHandle result;
|
||||||
jlong* indptr = jenv->GetLongArrayElements(jindptr, NULL);
|
jlong* indptr = jenv->GetLongArrayElements(jindptr, NULL);
|
||||||
jint* indices = jenv->GetIntArrayElements(jindices, 0);
|
jint* indices = jenv->GetIntArrayElements(jindices, 0);
|
||||||
@ -222,7 +222,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
|
|||||||
bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr);
|
bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr);
|
||||||
bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata);
|
bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata);
|
||||||
|
|
||||||
int ret = (jint) XGDMatrixCreateFromCSC((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, &result);
|
int ret = (jint) XGDMatrixCreateFromCSCEx((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, jrow, &result);
|
||||||
setHandle(jenv, jout, result);
|
setHandle(jenv, jout, result);
|
||||||
//release
|
//release
|
||||||
jenv->ReleaseLongArrayElements(jindptr, indptr, 0);
|
jenv->ReleaseLongArrayElements(jindptr, indptr, 0);
|
||||||
@ -232,6 +232,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
|
|||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
* Method: XGDMatrixCreateFromMat
|
* Method: XGDMatrixCreateFromMat
|
||||||
|
|||||||
@ -33,19 +33,19 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
|
|||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
* Method: XGDMatrixCreateFromCSR
|
* Method: XGDMatrixCreateFromCSREx
|
||||||
* Signature: ([J[I[F[J)I
|
* Signature: ([J[I[F[J)I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSR
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSREx
|
||||||
(JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jlongArray);
|
(JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jint, jlongArray);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
* Method: XGDMatrixCreateFromCSC
|
* Method: XGDMatrixCreateFromCSCEx
|
||||||
* Signature: ([J[I[F[J)I
|
* Signature: ([J[I[F[J)I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSC
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSCEx
|
||||||
(JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jlongArray);
|
(JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jint, jlongArray);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
|
|||||||
@ -91,6 +91,78 @@ public class DMatrixTest {
|
|||||||
TestCase.assertTrue(Arrays.equals(label1, label2));
|
TestCase.assertTrue(Arrays.equals(label1, label2));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCreateFromCSREx() throws XGBoostError {
|
||||||
|
//create Matrix from csr format sparse Matrix and labels
|
||||||
|
/**
|
||||||
|
* sparse matrix
|
||||||
|
* 1 0 2 3 0
|
||||||
|
* 4 0 2 3 5
|
||||||
|
* 3 1 2 5 0
|
||||||
|
*/
|
||||||
|
float[] data = new float[]{1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5};
|
||||||
|
int[] colIndex = new int[]{0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3};
|
||||||
|
long[] rowHeaders = new long[]{0, 3, 7, 11};
|
||||||
|
DMatrix dmat1 = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR, 5);
|
||||||
|
//check row num
|
||||||
|
TestCase.assertTrue(dmat1.rowNum() == 3);
|
||||||
|
//test set label
|
||||||
|
float[] label1 = new float[]{1, 0, 1};
|
||||||
|
dmat1.setLabel(label1);
|
||||||
|
float[] label2 = dmat1.getLabel();
|
||||||
|
TestCase.assertTrue(Arrays.equals(label1, label2));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCreateFromCSC() throws XGBoostError {
|
||||||
|
//create Matrix from csc format sparse Matrix and labels
|
||||||
|
/**
|
||||||
|
* sparse matrix
|
||||||
|
* 1 0 2
|
||||||
|
* 3 0 4
|
||||||
|
* 0 2 3
|
||||||
|
* 5 3 1
|
||||||
|
* 2 5 0
|
||||||
|
*/
|
||||||
|
float[] data = new float[]{1, 3, 5, 2, 2, 3, 5, 2, 4, 3, 1};
|
||||||
|
int[] rowIndex = new int[]{0, 1, 3, 4, 2, 3, 4, 0, 1, 2, 3};
|
||||||
|
long[] colHeaders = new long[]{0, 4, 7, 11};
|
||||||
|
DMatrix dmat1 = new DMatrix(colHeaders, rowIndex, data, DMatrix.SparseType.CSC);
|
||||||
|
//check row num
|
||||||
|
System.out.println(dmat1.rowNum());
|
||||||
|
TestCase.assertTrue(dmat1.rowNum() == 5);
|
||||||
|
//test set label
|
||||||
|
float[] label1 = new float[]{1, 0, 1, 1, 1};
|
||||||
|
dmat1.setLabel(label1);
|
||||||
|
float[] label2 = dmat1.getLabel();
|
||||||
|
TestCase.assertTrue(Arrays.equals(label1, label2));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCreateFromCSCEx() throws XGBoostError {
|
||||||
|
//create Matrix from csc format sparse Matrix and labels
|
||||||
|
/**
|
||||||
|
* sparse matrix
|
||||||
|
* 1 0 2
|
||||||
|
* 3 0 4
|
||||||
|
* 0 2 3
|
||||||
|
* 5 3 1
|
||||||
|
* 2 5 0
|
||||||
|
*/
|
||||||
|
float[] data = new float[]{1, 3, 5, 2, 2, 3, 5, 2, 4, 3, 1};
|
||||||
|
int[] rowIndex = new int[]{0, 1, 3, 4, 2, 3, 4, 0, 1, 2, 3};
|
||||||
|
long[] colHeaders = new long[]{0, 4, 7, 11};
|
||||||
|
DMatrix dmat1 = new DMatrix(colHeaders, rowIndex, data, DMatrix.SparseType.CSC, 5);
|
||||||
|
//check row num
|
||||||
|
System.out.println(dmat1.rowNum());
|
||||||
|
TestCase.assertTrue(dmat1.rowNum() == 5);
|
||||||
|
//test set label
|
||||||
|
float[] label1 = new float[]{1, 0, 1, 1, 1};
|
||||||
|
dmat1.setLabel(label1);
|
||||||
|
float[] label2 = dmat1.getLabel();
|
||||||
|
TestCase.assertTrue(Arrays.equals(label1, label2));
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCreateFromDenseMatrix() throws XGBoostError {
|
public void testCreateFromDenseMatrix() throws XGBoostError {
|
||||||
//create DMatrix from 10*5 dense matrix
|
//create DMatrix from 10*5 dense matrix
|
||||||
|
|||||||
@ -56,6 +56,67 @@ class DMatrixSuite extends FunSuite {
|
|||||||
assert(label2 === label1)
|
assert(label2 === label1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("create DMatrix from CSREx") {
|
||||||
|
// create Matrix from csr format sparse Matrix and labels
|
||||||
|
/**
|
||||||
|
* sparse matrix
|
||||||
|
* 1 0 2 3 0
|
||||||
|
* 4 0 2 3 5
|
||||||
|
* 3 1 2 5 0
|
||||||
|
*/
|
||||||
|
val data = List[Float](1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5).toArray
|
||||||
|
val colIndex = List(0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3).toArray
|
||||||
|
val rowHeaders = List[Long](0, 3, 7, 11).toArray
|
||||||
|
val dmat1 = new DMatrix(rowHeaders, colIndex, data, JDMatrix.SparseType.CSR, 5)
|
||||||
|
assert(dmat1.rowNum === 3)
|
||||||
|
val label1 = List[Float](1, 0, 1).toArray
|
||||||
|
dmat1.setLabel(label1)
|
||||||
|
val label2 = dmat1.getLabel
|
||||||
|
assert(label2 === label1)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("create DMatrix from CSC") {
|
||||||
|
// create Matrix from csc format sparse Matrix and labels
|
||||||
|
/**
|
||||||
|
* sparse matrix
|
||||||
|
* 1 0 2
|
||||||
|
* 3 0 4
|
||||||
|
* 0 2 3
|
||||||
|
* 5 3 1
|
||||||
|
* 2 5 0
|
||||||
|
*/
|
||||||
|
val data = List[Float](1, 3, 5, 2, 2, 3, 5, 2, 4, 3, 1).toArray
|
||||||
|
val rowIndex = List(0, 1, 3, 4, 2, 3, 4, 0, 1, 2, 3).toArray
|
||||||
|
val colHeaders = List[Long](0, 4, 7, 11).toArray
|
||||||
|
val dmat1 = new DMatrix(colHeaders, rowIndex, data, JDMatrix.SparseType.CSC)
|
||||||
|
assert(dmat1.rowNum === 5)
|
||||||
|
val label1 = List[Float](1, 0, 1, 1, 1).toArray
|
||||||
|
dmat1.setLabel(label1)
|
||||||
|
val label2 = dmat1.getLabel
|
||||||
|
assert(label2 === label1)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("create DMatrix from CSCEx") {
|
||||||
|
// create Matrix from csc format sparse Matrix and labels
|
||||||
|
/**
|
||||||
|
* sparse matrix
|
||||||
|
* 1 0 2
|
||||||
|
* 3 0 4
|
||||||
|
* 0 2 3
|
||||||
|
* 5 3 1
|
||||||
|
* 2 5 0
|
||||||
|
*/
|
||||||
|
val data = List[Float](1, 3, 5, 2, 2, 3, 5, 2, 4, 3, 1).toArray
|
||||||
|
val rowIndex = List(0, 1, 3, 4, 2, 3, 4, 0, 1, 2, 3).toArray
|
||||||
|
val colHeaders = List[Long](0, 4, 7, 11).toArray
|
||||||
|
val dmat1 = new DMatrix(colHeaders, rowIndex, data, JDMatrix.SparseType.CSC, 5)
|
||||||
|
assert(dmat1.rowNum === 5)
|
||||||
|
val label1 = List[Float](1, 0, 1, 1, 1).toArray
|
||||||
|
dmat1.setLabel(label1)
|
||||||
|
val label2 = dmat1.getLabel
|
||||||
|
assert(label2 === label1)
|
||||||
|
}
|
||||||
|
|
||||||
test("create DMatrix from DenseMatrix") {
|
test("create DMatrix from DenseMatrix") {
|
||||||
val nrow = 10
|
val nrow = 10
|
||||||
val ncol = 5
|
val ncol = 5
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user