[jvm-packages] Robust dmatrix creation (#1613)
* add back train method but mark as deprecated * robust matrix creation in jvm
This commit is contained in:
@@ -92,12 +92,38 @@ public class DMatrix {
|
||||
* @param st Type of sparsity.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
@Deprecated
|
||||
public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError {
|
||||
long[] out = new long[1];
|
||||
if (st == SparseType.CSR) {
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSR(headers, indices, data, out));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSREx(headers, indices, data, 0, out));
|
||||
} else if (st == SparseType.CSC) {
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSC(headers, indices, data, out));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSCEx(headers, indices, data, 0, out));
|
||||
} else {
|
||||
throw new UnknownError("unknow sparsetype");
|
||||
}
|
||||
handle = out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Create DMatrix from Sparse matrix in CSR/CSC format.
|
||||
* @param headers The row index of the matrix.
|
||||
* @param indices The indices of presenting entries.
|
||||
* @param data The data content.
|
||||
* @param st Type of sparsity.
|
||||
* @param shapeParam when st is CSR, it specifies the column number, otherwise it is taken as
|
||||
* row number
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public DMatrix(long[] headers, int[] indices, float[] data, SparseType st, int shapeParam)
|
||||
throws XGBoostError {
|
||||
long[] out = new long[1];
|
||||
if (st == SparseType.CSR) {
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSREx(headers, indices, data,
|
||||
shapeParam, out));
|
||||
} else if (st == SparseType.CSC) {
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSCEx(headers, indices, data,
|
||||
shapeParam, out));
|
||||
} else {
|
||||
throw new UnknownError("unknow sparsetype");
|
||||
}
|
||||
|
||||
@@ -30,11 +30,11 @@ class XGBoostJNI {
|
||||
final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter,
|
||||
String cache_info, long[] out);
|
||||
|
||||
public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data,
|
||||
long[] out);
|
||||
public final static native int XGDMatrixCreateFromCSREx(long[] indptr, int[] indices, float[] data,
|
||||
int shapeParam, long[] out);
|
||||
|
||||
public final static native int XGDMatrixCreateFromCSC(long[] colptr, int[] indices, float[] data,
|
||||
long[] out);
|
||||
public final static native int XGDMatrixCreateFromCSCEx(long[] colptr, int[] indices, float[] data,
|
||||
int shapeParam, long[] out);
|
||||
|
||||
public final static native int XGDMatrixCreateFromMat(float[] data, int nrow, int ncol,
|
||||
float missing, long[] out);
|
||||
|
||||
@@ -51,10 +51,27 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
||||
* @param st sparse matrix type (CSR or CSC)
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
@deprecated
|
||||
def this(headers: Array[Long], indices: Array[Int], data: Array[Float], st: JDMatrix.SparseType) {
|
||||
this(new JDMatrix(headers, indices, data, st))
|
||||
}
|
||||
|
||||
/**
|
||||
* create DMatrix from sparse matrix
|
||||
*
|
||||
* @param headers index to headers (rowHeaders for CSR or colHeaders for CSC)
|
||||
* @param indices Indices (colIndexs for CSR or rowIndexs for CSC)
|
||||
* @param data non zero values (sequence by row for CSR or by col for CSC)
|
||||
* @param st sparse matrix type (CSR or CSC)
|
||||
* @param shapeParam when st is CSR, it specifies the column number, otherwise it is taken as
|
||||
* row number
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def this(headers: Array[Long], indices: Array[Int], data: Array[Float], st: JDMatrix.SparseType,
|
||||
shapeParam: Int) {
|
||||
this(new JDMatrix(headers, indices, data, st, shapeParam))
|
||||
}
|
||||
|
||||
/**
|
||||
* create DMatrix from dense matrix
|
||||
*
|
||||
|
||||
Reference in New Issue
Block a user