[JVM] Add Iterator loading API
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
package ml.dmlc.xgboost4j;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Iterator;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
@@ -47,6 +48,33 @@ public class DMatrix {
|
||||
CSC;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create DMatrix from iterator.
|
||||
*
|
||||
* @param iter The data iterator of mini batch to provide the data.
|
||||
* @param cache_info Cache path information, used for external memory setting, can be null.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public DMatrix(Iterator<DataBatch> iter, String cache_info) throws XGBoostError {
|
||||
if (iter == null) {
|
||||
throw new NullPointerException("iter: null");
|
||||
}
|
||||
try {
|
||||
logger.info(iter.getClass().getMethod("next").toString());
|
||||
} catch(NoSuchMethodException e) {
|
||||
logger.info(e.toString());
|
||||
}
|
||||
long[] out = new long[1];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromDataIter(iter, cache_info, out));
|
||||
handle = out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Create DMatrix by loading libsvm file from dataPath
|
||||
*
|
||||
* @param dataPath The path to the data.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public DMatrix(String dataPath) throws XGBoostError {
|
||||
if (dataPath == null) {
|
||||
throw new NullPointerException("dataPath: null");
|
||||
@@ -56,6 +84,14 @@ public class DMatrix {
|
||||
handle = out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Create DMatrix from Sparse matrix in CSR/CSC format.
|
||||
* @param headers The row index of the matrix.
|
||||
* @param indices The indices of presenting entries.
|
||||
* @param data The data content.
|
||||
* @param st Type of sparsity.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError {
|
||||
long[] out = new long[1];
|
||||
if (st == SparseType.CSR) {
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
package ml.dmlc.xgboost4j;
|
||||
|
||||
/**
|
||||
* A mini-batch of data that can be converted to DMatrix.
|
||||
* The data is in sparse matrix CSR format.
|
||||
*
|
||||
* Usually this object is not needed.
|
||||
*
|
||||
* This class is used to support advanced creation of DMatrix from Iterator of DataBatch,
|
||||
*/
|
||||
public class DataBatch {
|
||||
/** The offset of each rows in the sparse matrix */
|
||||
long[] rowOffset = null;
|
||||
/** weight of each data point, can be null */
|
||||
float[] weight = null;
|
||||
/** label of each data point, can be null */
|
||||
float[] label = null;
|
||||
/** index of each feature(column) in the sparse matrix */
|
||||
int[] featureIndex = null;
|
||||
/** value of each non-missing entry in the sparse matrix */
|
||||
float[] featureValue = null;
|
||||
/**
|
||||
* Get number of rows in the data batch.
|
||||
* @return Number of rows in the data batch.
|
||||
*/
|
||||
public int numRows() {
|
||||
return rowOffset.length - 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* Shallow copy a DataBatch
|
||||
* @return a copy of the batch
|
||||
*/
|
||||
public DataBatch shallowCopy() {
|
||||
DataBatch b = new DataBatch();
|
||||
b.rowOffset = this.rowOffset;
|
||||
b.weight = this.weight;
|
||||
b.label = this.label;
|
||||
b.featureIndex = this.featureIndex;
|
||||
b.featureValue = this.featureValue;
|
||||
return b;
|
||||
}
|
||||
}
|
||||
@@ -15,6 +15,7 @@
|
||||
*/
|
||||
package ml.dmlc.xgboost4j;
|
||||
|
||||
|
||||
/**
|
||||
* xgboost JNI functions
|
||||
* change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster
|
||||
@@ -26,6 +27,8 @@ class XgboostJNI {
|
||||
|
||||
public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);
|
||||
|
||||
public final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter, String cache_info, long[] out);
|
||||
|
||||
public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data,
|
||||
long[] out);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user