Use array interface for CSC matrix. (#8672)

* Use array interface for CSC matrix.

Use array interface for CSC matrix and align the interface with CSR and dense.

- Fix nthread issue in the R package DMatrix.
- Unify the behavior of handling `missing` with other inputs.
- Unify the behavior of handling `missing` around R, Python, Java, and Scala DMatrix.
- Expose `num_non_missing` to the JVM interface.
- Deprecate old CSR and CSC constructors.
This commit is contained in:
Jiaming Yuan
2023-02-05 01:59:46 +08:00
committed by GitHub
parent 213b5602d9
commit c1786849e3
23 changed files with 673 additions and 380 deletions

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -79,17 +79,9 @@ public class DMatrix {
* @throws XGBoostError
*/
@Deprecated
public DMatrix(long[] headers, int[] indices, float[] data, DMatrix.SparseType st)
throws XGBoostError {
long[] out = new long[1];
if (st == SparseType.CSR) {
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromCSREx(headers, indices, data, 0, out));
} else if (st == SparseType.CSC) {
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromCSCEx(headers, indices, data, 0, out));
} else {
throw new UnknownError("unknow sparsetype");
}
handle = out[0];
public DMatrix(long[] headers, int[] indices, float[] data,
DMatrix.SparseType st) throws XGBoostError {
this(headers, indices, data, st, 0, Float.NaN, -1);
}
/**
@@ -102,15 +94,20 @@ public class DMatrix {
* row number
* @throws XGBoostError
*/
public DMatrix(long[] headers, int[] indices, float[] data, DMatrix.SparseType st, int shapeParam)
throws XGBoostError {
public DMatrix(long[] headers, int[] indices, float[] data, DMatrix.SparseType st,
int shapeParam) throws XGBoostError {
this(headers, indices, data, st, shapeParam, Float.NaN, -1);
}
public DMatrix(long[] headers, int[] indices, float[] data, DMatrix.SparseType st, int shapeParam,
float missing, int nthread) throws XGBoostError {
long[] out = new long[1];
if (st == SparseType.CSR) {
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromCSREx(headers, indices, data,
shapeParam, out));
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromCSR(headers, indices, data,
shapeParam, missing, nthread, out));
} else if (st == SparseType.CSC) {
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromCSCEx(headers, indices, data,
shapeParam, out));
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromCSC(headers, indices, data,
shapeParam, missing, nthread, out));
} else {
throw new UnknownError("unknow sparsetype");
}
@@ -425,6 +422,18 @@ public class DMatrix {
return rowNum[0];
}
/**
* Get the number of non-missing values of DMatrix.
*
* @return The number of non-missing values
* @throws XGBoostError native error
*/
public long nonMissingNum() throws XGBoostError {
long[] n = new long[1];
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixNumNonMissing(handle, n));
return n[0];
}
/**
* save DMatrix to filePath
*/

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -56,11 +56,15 @@ class XGBoostJNI {
final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter,
String cache_info, long[] out);
public final static native int XGDMatrixCreateFromCSREx(long[] indptr, int[] indices, float[] data,
int shapeParam, long[] out);
public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices,
float[] data, int shapeParam,
float missing, int nthread,
long[] out);
public final static native int XGDMatrixCreateFromCSCEx(long[] colptr, int[] indices, float[] data,
int shapeParam, long[] out);
public final static native int XGDMatrixCreateFromCSC(long[] colptr, int[] indices,
float[] data, int shapeParam,
float missing, int nthread,
long[] out);
public final static native int XGDMatrixCreateFromMat(float[] data, int nrow, int ncol,
float missing, long[] out);
@@ -96,6 +100,7 @@ class XGBoostJNI {
long[] outLength, String[][] outValues);
public final static native int XGDMatrixNumRow(long handle, long[] row);
public final static native int XGDMatrixNumNonMissing(long handle, long[] nonMissings);
public final static native int XGBoosterCreate(long[] handles, long[] out);

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014,2021 by Contributors
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -54,7 +54,7 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
@throws(classOf[XGBoostError])
@deprecated
def this(headers: Array[Long], indices: Array[Int], data: Array[Float], st: JDMatrix.SparseType) {
this(new JDMatrix(headers, indices, data, st))
this(new JDMatrix(headers, indices, data, st, 0, Float.NaN, -1))
}
/**
@@ -70,7 +70,25 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
@throws(classOf[XGBoostError])
def this(headers: Array[Long], indices: Array[Int], data: Array[Float], st: JDMatrix.SparseType,
shapeParam: Int) {
this(new JDMatrix(headers, indices, data, st, shapeParam))
this(new JDMatrix(headers, indices, data, st, shapeParam, Float.NaN, -1))
}
/**
* create DMatrix from sparse matrix
*
* @param headers index to headers (rowHeaders for CSR or colHeaders for CSC)
* @param indices Indices (colIndexs for CSR or rowIndexs for CSC)
* @param data non zero values (sequence by row for CSR or by col for CSC)
* @param st sparse matrix type (CSR or CSC)
* @param shapeParam when st is CSR, it specifies the column number, otherwise it is taken as
* row number
* @param missing missing value
* @param nthread The number of threads used for constructing DMatrix
*/
@throws(classOf[XGBoostError])
def this(headers: Array[Long], indices: Array[Int], data: Array[Float], st: JDMatrix.SparseType,
shapeParam: Int, missing: Float, nthread: Int) {
this(new JDMatrix(headers, indices, data, st, shapeParam, missing, nthread))
}
/**
@@ -78,7 +96,7 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
* @param columnBatch the XGBoost ColumnBatch to provide the cuda array interface
* of feature columns
* @param missing missing value
* @param nthread threads number
* @param nthread The number of threads used for constructing DMatrix
*/
@throws(classOf[XGBoostError])
def this(columnBatch: ColumnBatch, missing: Float, nthread: Int) {
@@ -246,6 +264,16 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
jDMatrix.rowNum
}
/**
* Get the number of non-missing values of DMatrix.
*
* @return The number of non-missing values
*/
@throws(classOf[XGBoostError])
def nonMissingNum: Long = {
jDMatrix.nonMissingNum
}
/**
* save DMatrix to filePath
*