[JVM] Add Iterator loading API

This commit is contained in:
tqchen
2016-03-04 17:22:08 -08:00
parent 770b3451ca
commit 86871d4be9
10 changed files with 451 additions and 5 deletions

View File

@@ -16,6 +16,7 @@
package ml.dmlc.xgboost4j;
import java.io.IOException;
import java.util.Iterator;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@@ -47,6 +48,33 @@ public class DMatrix {
CSC;
}
/**
* Create DMatrix from iterator.
*
* @param iter The data iterator of mini batch to provide the data.
* @param cache_info Cache path information, used for external memory setting, can be null.
* @throws XGBoostError
*/
public DMatrix(Iterator<DataBatch> iter, String cache_info) throws XGBoostError {
if (iter == null) {
throw new NullPointerException("iter: null");
}
try {
logger.info(iter.getClass().getMethod("next").toString());
} catch(NoSuchMethodException e) {
logger.info(e.toString());
}
long[] out = new long[1];
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromDataIter(iter, cache_info, out));
handle = out[0];
}
/**
* Create DMatrix by loading libsvm file from dataPath
*
* @param dataPath The path to the data.
* @throws XGBoostError
*/
public DMatrix(String dataPath) throws XGBoostError {
if (dataPath == null) {
throw new NullPointerException("dataPath: null");
@@ -56,6 +84,14 @@ public class DMatrix {
handle = out[0];
}
/**
* Create DMatrix from Sparse matrix in CSR/CSC format.
* @param headers The row index of the matrix.
* @param indices The indices of presenting entries.
* @param data The data content.
* @param st Type of sparsity.
* @throws XGBoostError
*/
public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError {
long[] out = new long[1];
if (st == SparseType.CSR) {

View File

@@ -0,0 +1,43 @@
package ml.dmlc.xgboost4j;
/**
* A mini-batch of data that can be converted to DMatrix.
* The data is in sparse matrix CSR format.
*
* Usually this object is not needed.
*
* This class is used to support advanced creation of DMatrix from Iterator of DataBatch,
*/
public class DataBatch {
/** The offset of each rows in the sparse matrix */
long[] rowOffset = null;
/** weight of each data point, can be null */
float[] weight = null;
/** label of each data point, can be null */
float[] label = null;
/** index of each feature(column) in the sparse matrix */
int[] featureIndex = null;
/** value of each non-missing entry in the sparse matrix */
float[] featureValue = null;
/**
* Get number of rows in the data batch.
* @return Number of rows in the data batch.
*/
public int numRows() {
return rowOffset.length - 1;
}
/**
* Shallow copy a DataBatch
* @return a copy of the batch
*/
public DataBatch shallowCopy() {
DataBatch b = new DataBatch();
b.rowOffset = this.rowOffset;
b.weight = this.weight;
b.label = this.label;
b.featureIndex = this.featureIndex;
b.featureValue = this.featureValue;
return b;
}
}

View File

@@ -15,6 +15,7 @@
*/
package ml.dmlc.xgboost4j;
/**
* xgboost JNI functions
* change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster
@@ -26,6 +27,8 @@ class XgboostJNI {
public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);
public final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter, String cache_info, long[] out);
public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data,
long[] out);