[JVM] Add Iterator loading API

This commit is contained in:
tqchen
2016-03-04 17:22:08 -08:00
parent 770b3451ca
commit 86871d4be9
10 changed files with 451 additions and 5 deletions

View File

@@ -16,6 +16,7 @@
package ml.dmlc.xgboost4j;
import java.io.IOException;
import java.util.Iterator;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@@ -47,6 +48,33 @@ public class DMatrix {
CSC;
}
/**
* Create DMatrix from iterator.
*
* @param iter The data iterator of mini batch to provide the data.
* @param cache_info Cache path information, used for external memory setting, can be null.
* @throws XGBoostError
*/
public DMatrix(Iterator<DataBatch> iter, String cache_info) throws XGBoostError {
if (iter == null) {
throw new NullPointerException("iter: null");
}
try {
logger.info(iter.getClass().getMethod("next").toString());
} catch(NoSuchMethodException e) {
logger.info(e.toString());
}
long[] out = new long[1];
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromDataIter(iter, cache_info, out));
handle = out[0];
}
/**
* Create DMatrix by loading libsvm file from dataPath
*
* @param dataPath The path to the data.
* @throws XGBoostError
*/
public DMatrix(String dataPath) throws XGBoostError {
if (dataPath == null) {
throw new NullPointerException("dataPath: null");
@@ -56,6 +84,14 @@ public class DMatrix {
handle = out[0];
}
/**
* Create DMatrix from Sparse matrix in CSR/CSC format.
* @param headers The row index of the matrix.
* @param indices The indices of presenting entries.
* @param data The data content.
* @param st Type of sparsity.
* @throws XGBoostError
*/
public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError {
long[] out = new long[1];
if (st == SparseType.CSR) {

View File

@@ -0,0 +1,43 @@
package ml.dmlc.xgboost4j;
/**
* A mini-batch of data that can be converted to DMatrix.
* The data is in sparse matrix CSR format.
*
* Usually this object is not needed.
*
* This class is used to support advanced creation of DMatrix from Iterator of DataBatch,
*/
public class DataBatch {
/** The offset of each rows in the sparse matrix */
long[] rowOffset = null;
/** weight of each data point, can be null */
float[] weight = null;
/** label of each data point, can be null */
float[] label = null;
/** index of each feature(column) in the sparse matrix */
int[] featureIndex = null;
/** value of each non-missing entry in the sparse matrix */
float[] featureValue = null;
/**
* Get number of rows in the data batch.
* @return Number of rows in the data batch.
*/
public int numRows() {
return rowOffset.length - 1;
}
/**
* Shallow copy a DataBatch
* @return a copy of the batch
*/
public DataBatch shallowCopy() {
DataBatch b = new DataBatch();
b.rowOffset = this.rowOffset;
b.weight = this.weight;
b.label = this.label;
b.featureIndex = this.featureIndex;
b.featureValue = this.featureValue;
return b;
}
}

View File

@@ -15,6 +15,7 @@
*/
package ml.dmlc.xgboost4j;
/**
* xgboost JNI functions
* change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster
@@ -26,6 +27,8 @@ class XgboostJNI {
public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);
public final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter, String cache_info, long[] out);
public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data,
long[] out);

View File

@@ -20,13 +20,124 @@
#include <vector>
#include <string>
//helper functions
//set handle
// helper functions
// set handle
void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle) {
long out = (long) handle;
jenv->SetLongArrayRegion(jhandle, 0, 1, &out);
}
// global JVM
static JavaVM* global_jvm = nullptr;
// overrides JNI on load
jint JNI_OnLoad(JavaVM *vm, void *reserved) {
global_jvm = vm;
return JNI_VERSION_1_6;
}
XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
DataIterHandle data_handle,
XGBCallbackSetData* set_function,
DataHolderHandle set_function_handle) {
jobject jiter = static_cast<jobject>(data_handle);
JNIEnv* jenv;
int jni_status = global_jvm->GetEnv((void **)&jenv, JNI_VERSION_1_6);
if (jni_status == JNI_EDETACHED) {
global_jvm->AttachCurrentThread(reinterpret_cast<void **>(&jenv), nullptr);
} else {
CHECK(jni_status == JNI_OK);
}
try {
jclass iterClass = jenv->FindClass("java/util/Iterator");
jmethodID hasNext = jenv->GetMethodID(iterClass,
"hasNext", "()Z");
jmethodID next = jenv->GetMethodID(iterClass,
"next", "()Ljava/lang/Object;");
int ret_value;
if (jenv->CallBooleanMethod(jiter, hasNext)) {
ret_value = 1;
jobject batch = jenv->CallObjectMethod(jiter, next);
jclass batchClass = jenv->GetObjectClass(batch);
jlongArray joffset = (jlongArray)jenv->GetObjectField(
batch, jenv->GetFieldID(batchClass, "rowOffset", "[J"));
jfloatArray jlabel = (jfloatArray)jenv->GetObjectField(
batch, jenv->GetFieldID(batchClass, "label", "[F"));
jfloatArray jweight = (jfloatArray)jenv->GetObjectField(
batch, jenv->GetFieldID(batchClass, "weight", "[F"));
jintArray jindex = (jintArray)jenv->GetObjectField(
batch, jenv->GetFieldID(batchClass, "featureIndex", "[I"));
jfloatArray jvalue = (jfloatArray)jenv->GetObjectField(
batch, jenv->GetFieldID(batchClass, "featureValue", "[F"));
XGBoostBatchCSR cbatch;
cbatch.size = jenv->GetArrayLength(joffset) - 1;
cbatch.offset = jenv->GetLongArrayElements(joffset, 0);
if (jlabel != nullptr) {
cbatch.label = jenv->GetFloatArrayElements(jlabel, 0);
CHECK_EQ(jenv->GetArrayLength(jlabel), static_cast<long>(cbatch.size))
<< "batch.label.length must equal batch.numRows()";
} else {
cbatch.label = nullptr;
}
if (jweight != nullptr) {
cbatch.weight = jenv->GetFloatArrayElements(jweight, 0);
CHECK_EQ(jenv->GetArrayLength(jweight), static_cast<long>(cbatch.size))
<< "batch.weight.length must equal batch.numRows()";
} else {
cbatch.weight = nullptr;
}
long max_elem = cbatch.offset[cbatch.size];
cbatch.index = jenv->GetIntArrayElements(jindex, 0);
cbatch.value = jenv->GetFloatArrayElements(jvalue, 0);
CHECK_EQ(jenv->GetArrayLength(jindex), max_elem)
<< "batch.index.length must equal batch.offset.back()";
CHECK_EQ(jenv->GetArrayLength(jvalue), max_elem)
<< "batch.index.length must equal batch.offset.back()";
// cbatch is ready
CHECK_EQ((*set_function)(set_function_handle, cbatch), 0)
<< XGBGetLastError();
// release the elements.
jenv->ReleaseLongArrayElements(joffset, cbatch.offset, 0);
jenv->DeleteLocalRef(joffset);
if (jlabel != nullptr) {
jenv->ReleaseFloatArrayElements(jlabel, cbatch.label, 0);
jenv->DeleteLocalRef(jlabel);
}
if (jweight != nullptr) {
jenv->ReleaseFloatArrayElements(jweight, cbatch.weight, 0);
jenv->DeleteLocalRef(jweight);
}
jenv->ReleaseIntArrayElements(jindex, cbatch.index, 0);
jenv->DeleteLocalRef(jindex);
jenv->ReleaseFloatArrayElements(jvalue, cbatch.value, 0);
jenv->DeleteLocalRef(jvalue);
jenv->DeleteLocalRef(batch);
jenv->DeleteLocalRef(batchClass);
ret_value = 1;
} else {
ret_value = 0;
}
jenv->DeleteLocalRef(iterClass);
// only detach if it is a async call.
if (jni_status == JNI_EDETACHED) {
global_jvm->DetachCurrentThread();
}
return ret_value;
} catch(dmlc::Error e) {
// only detach if it is a async call.
if (jni_status == JNI_EDETACHED) {
global_jvm->DetachCurrentThread();
}
LOG(FATAL) << e.what();
return -1;
}
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBGetLastError
* Signature: ()Ljava/lang/String;
*/
JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError
(JNIEnv *jenv, jclass jcls) {
jstring jresult = 0;
@@ -37,6 +148,32 @@ JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError
return jresult;
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixCreateFromDataIter
* Signature: (Ljava/util/Iterator;Ljava/lang/String;[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromDataIter
(JNIEnv *jenv, jclass jcls, jobject jiter, jstring jcache_info, jlongArray jout) {
DMatrixHandle result;
const char* cache_info = nullptr;
if (jcache_info != nullptr) {
cache_info = jenv->GetStringUTFChars(jcache_info, 0);
}
int ret = XGDMatrixCreateFromDataIter(
jiter, XGBoost4jCallbackDataIterNext, cache_info, &result);
if (cache_info) {
jenv->ReleaseStringUTFChars(jcache_info, cache_info);
}
setHandle(jenv, jout, result);
return ret;
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixCreateFromFile
* Signature: (Ljava/lang/String;I[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile
(JNIEnv *jenv, jclass jcls, jstring jfname, jint jsilent, jlongArray jout) {
DMatrixHandle result;

View File

@@ -23,6 +23,14 @@ JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile
(JNIEnv *, jclass, jstring, jint, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixCreateFromDataIter
* Signature: (Ljava/util/Iterator;Ljava/lang/String;[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromDataIter
(JNIEnv *, jclass, jobject, jstring, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixCreateFromCSR

View File

@@ -28,6 +28,41 @@ import org.junit.Test;
*/
public class DMatrixTest {
@Test
public void testCreateFromDataIterator() throws XGBoostError {
//create DMatrix from DataIterator
/**
* sparse matrix
* 1 0 2 3 0
* 4 0 2 3 5
* 3 1 2 5 0
*/
DataBatch batch = new DataBatch();
batch.featureIndex = new int[]{0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3};
batch.featureValue = new float[]{1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5};
batch.rowOffset = new long[]{0, 3, 7, 11};
batch.label = new float[] {0.1f, 0.2f, 0.3f};
java.util.ArrayList<Float> labelall = new java.util.ArrayList<Float>();
int nrep = 3;
java.util.List<DataBatch> blist = new java.util.LinkedList<DataBatch>();
for (int i = 0; i < nrep; ++i) {
batch.label = new float[] {0.1f+i, 0.2f+i, 0.3f+i};
blist.add(batch.shallowCopy());
for (float f : batch.label) {
labelall.add(f);
}
}
DMatrix dmat = new DMatrix(blist.iterator(), null);
// get label
float[] labels = dmat.getLabel();
// get label
TestCase.assertTrue(batch.label.length * nrep == labels.length);
for (int i = 0; i < labels.length; ++i) {
TestCase.assertTrue(labelall.get(i) == labels[i]);
}
}
@Test
public void testCreateFromFile() throws XGBoostError {
//create DMatrix from file