[jvm-packages] Robust dmatrix creation (#1613)

* add back train method but mark as deprecated

* robust matrix creation in jvm
This commit is contained in:
Nan Zhu
2016-09-26 13:35:04 -04:00
committed by GitHub
parent 915ac0b8fe
commit 37bc122c90
7 changed files with 197 additions and 20 deletions

View File

@@ -92,12 +92,38 @@ public class DMatrix {
* @param st Type of sparsity.
* @throws XGBoostError
*/
@Deprecated
public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError {
long[] out = new long[1];
if (st == SparseType.CSR) {
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSR(headers, indices, data, out));
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSREx(headers, indices, data, 0, out));
} else if (st == SparseType.CSC) {
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSC(headers, indices, data, out));
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSCEx(headers, indices, data, 0, out));
} else {
throw new UnknownError("unknow sparsetype");
}
handle = out[0];
}
/**
* Create DMatrix from Sparse matrix in CSR/CSC format.
* @param headers The row index of the matrix.
* @param indices The indices of presenting entries.
* @param data The data content.
* @param st Type of sparsity.
* @param shapeParam when st is CSR, it specifies the column number, otherwise it is taken as
* row number
* @throws XGBoostError
*/
public DMatrix(long[] headers, int[] indices, float[] data, SparseType st, int shapeParam)
throws XGBoostError {
long[] out = new long[1];
if (st == SparseType.CSR) {
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSREx(headers, indices, data,
shapeParam, out));
} else if (st == SparseType.CSC) {
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSCEx(headers, indices, data,
shapeParam, out));
} else {
throw new UnknownError("unknow sparsetype");
}

View File

@@ -30,11 +30,11 @@ class XGBoostJNI {
final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter,
String cache_info, long[] out);
public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data,
long[] out);
public final static native int XGDMatrixCreateFromCSREx(long[] indptr, int[] indices, float[] data,
int shapeParam, long[] out);
public final static native int XGDMatrixCreateFromCSC(long[] colptr, int[] indices, float[] data,
long[] out);
public final static native int XGDMatrixCreateFromCSCEx(long[] colptr, int[] indices, float[] data,
int shapeParam, long[] out);
public final static native int XGDMatrixCreateFromMat(float[] data, int nrow, int ncol,
float missing, long[] out);

View File

@@ -51,10 +51,27 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
* @param st sparse matrix type (CSR or CSC)
*/
@throws(classOf[XGBoostError])
@deprecated
def this(headers: Array[Long], indices: Array[Int], data: Array[Float], st: JDMatrix.SparseType) {
this(new JDMatrix(headers, indices, data, st))
}
/**
* create DMatrix from sparse matrix
*
* @param headers index to headers (rowHeaders for CSR or colHeaders for CSC)
* @param indices Indices (colIndexs for CSR or rowIndexs for CSC)
* @param data non zero values (sequence by row for CSR or by col for CSC)
* @param st sparse matrix type (CSR or CSC)
* @param shapeParam when st is CSR, it specifies the column number, otherwise it is taken as
* row number
*/
@throws(classOf[XGBoostError])
def this(headers: Array[Long], indices: Array[Int], data: Array[Float], st: JDMatrix.SparseType,
shapeParam: Int) {
this(new JDMatrix(headers, indices, data, st, shapeParam))
}
/**
* create DMatrix from dense matrix
*