[JVM] Add Iterator loading API
This commit is contained in:
parent
770b3451ca
commit
86871d4be9
@ -1 +1 @@
|
|||||||
Subproject commit 71360023dba458bdc9f1bc6f4309c1a107cb83a0
|
Subproject commit 3f6ff43d3976d5b6d5001608b0e3e526ecde098f
|
||||||
@ -12,6 +12,9 @@
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
// XGBoost C API will include APIs in Rabit C API
|
// XGBoost C API will include APIs in Rabit C API
|
||||||
|
XGB_EXTERN_C {
|
||||||
|
#include <stdio.h>
|
||||||
|
}
|
||||||
#include <rabit/c_api.h>
|
#include <rabit/c_api.h>
|
||||||
|
|
||||||
#if defined(_MSC_VER) || defined(_WIN32)
|
#if defined(_MSC_VER) || defined(_WIN32)
|
||||||
@ -26,6 +29,51 @@ typedef unsigned long bst_ulong; // NOLINT(*)
|
|||||||
typedef void *DMatrixHandle;
|
typedef void *DMatrixHandle;
|
||||||
/*! \brief handle to Booster */
|
/*! \brief handle to Booster */
|
||||||
typedef void *BoosterHandle;
|
typedef void *BoosterHandle;
|
||||||
|
/*! \brief handle to a data iterator */
|
||||||
|
typedef void *DataIterHandle;
|
||||||
|
/*! \brief handle to an internal data holder. */
|
||||||
|
typedef void *DataHolderHandle;
|
||||||
|
|
||||||
|
/*! \brief Mini batch in CSR (compressed sparse row) layout used in XGBoost data iteration */
|
||||||
|
typedef struct {
|
||||||
|
/*! \brief number of rows in the minibatch */
|
||||||
|
size_t size;
|
||||||
|
/*! \brief row offsets into index/value; consumers read size + 1 entries */
|
||||||
|
long* offset; // NOLINT(*)
|
||||||
|
/*! \brief labels of each instance, can be NULL */
|
||||||
|
float* label;
|
||||||
|
/*! \brief weight of each instance, can be NULL */
|
||||||
|
float* weight;
|
||||||
|
/*! \brief feature index of each stored entry */
|
||||||
|
int* index;
|
||||||
|
/*! \brief feature value of each stored entry */
|
||||||
|
float* value;
|
||||||
|
} XGBoostBatchCSR;
|
||||||
|
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Callback that hands one batch of data to the data holder.
|
||||||
|
* \param handle The handle to the data holder that receives the batch.
|
||||||
|
* \param batch The data content to be set.
|
||||||
|
*/
|
||||||
|
XGB_EXTERN_C typedef int XGBCallbackSetData(
|
||||||
|
DataHolderHandle handle, XGBoostBatchCSR batch);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief The data reading callback function.
|
||||||
|
* Each call yields at most one batch, i.e. a subset of the rows of the data.
|
||||||
|
*
|
||||||
|
* If there is data, the function will call set_function to set the data.
|
||||||
|
*
|
||||||
|
* \param data_handle The handle to the data iterator.
|
||||||
|
* \param set_function The callback used to pass the next batch to the consumer.
|
||||||
|
* \param set_function_handle The handle to be passed to set function.
|
||||||
|
* \return 0 if the end is reached and no batch is returned; non-zero when a batch was set.
|
||||||
|
*/
|
||||||
|
XGB_EXTERN_C typedef int XGBCallbackDataIterNext(
|
||||||
|
DataIterHandle data_handle,
|
||||||
|
XGBCallbackSetData* set_function,
|
||||||
|
DataHolderHandle set_function_handle);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief get string message of the last error
|
* \brief get string message of the last error
|
||||||
@ -50,6 +98,20 @@ XGB_DLL int XGDMatrixCreateFromFile(const char *fname,
|
|||||||
int silent,
|
int silent,
|
||||||
DMatrixHandle *out);
|
DMatrixHandle *out);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Create a DMatrix from a data iterator.
|
||||||
|
* \param data_handle The handle to the data iterator; it is passed back to callback.
|
||||||
|
* \param callback The callback used to fetch the data batch by batch.
|
||||||
|
* \param cache_info Additional information about cache file, can be null.
|
||||||
|
* \param out The handle of the created DMatrix.
|
||||||
|
* \return 0 when success, -1 when failure happens.
|
||||||
|
*/
|
||||||
|
XGB_DLL int XGDMatrixCreateFromDataIter(
|
||||||
|
DataIterHandle data_handle,
|
||||||
|
XGBCallbackDataIterNext* callback,
|
||||||
|
const char* cache_info,
|
||||||
|
DMatrixHandle *out);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief create a matrix content from csr format
|
* \brief create a matrix content from csr format
|
||||||
* \param indptr pointer to row headers
|
* \param indptr pointer to row headers
|
||||||
|
|||||||
@ -16,6 +16,7 @@
|
|||||||
package ml.dmlc.xgboost4j;
|
package ml.dmlc.xgboost4j;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.util.Iterator;
|
||||||
|
|
||||||
import org.apache.commons.logging.Log;
|
import org.apache.commons.logging.Log;
|
||||||
import org.apache.commons.logging.LogFactory;
|
import org.apache.commons.logging.LogFactory;
|
||||||
@ -47,6 +48,33 @@ public class DMatrix {
|
|||||||
CSC;
|
CSC;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Create DMatrix from iterator.
 *
 * The iterator is handed to the native library, which pulls batches from it
 * while the matrix is being built.
 *
 * @param iter The data iterator of mini batch to provide the data.
 * @param cache_info Cache path information, used for external memory setting, can be null.
 * @throws XGBoostError when the native call reports a failure
 * @throws NullPointerException when iter is null
 */
public DMatrix(Iterator<DataBatch> iter, String cache_info) throws XGBoostError {
  if (iter == null) {
    throw new NullPointerException("iter: null");
  }
  // The native side writes the new matrix handle into slot 0.
  long[] out = new long[1];
  JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromDataIter(iter, cache_info, out));
  handle = out[0];
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create DMatrix by loading libsvm file from dataPath
|
||||||
|
*
|
||||||
|
* @param dataPath The path to the data.
|
||||||
|
* @throws XGBoostError
|
||||||
|
*/
|
||||||
public DMatrix(String dataPath) throws XGBoostError {
|
public DMatrix(String dataPath) throws XGBoostError {
|
||||||
if (dataPath == null) {
|
if (dataPath == null) {
|
||||||
throw new NullPointerException("dataPath: null");
|
throw new NullPointerException("dataPath: null");
|
||||||
@ -56,6 +84,14 @@ public class DMatrix {
|
|||||||
handle = out[0];
|
handle = out[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create DMatrix from Sparse matrix in CSR/CSC format.
|
||||||
|
* @param headers The row index of the matrix.
|
||||||
|
* @param indices The indices of presenting entries.
|
||||||
|
* @param data The data content.
|
||||||
|
* @param st Type of sparsity.
|
||||||
|
* @throws XGBoostError
|
||||||
|
*/
|
||||||
public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError {
|
public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError {
|
||||||
long[] out = new long[1];
|
long[] out = new long[1];
|
||||||
if (st == SparseType.CSR) {
|
if (st == SparseType.CSR) {
|
||||||
|
|||||||
@ -0,0 +1,43 @@
|
|||||||
|
package ml.dmlc.xgboost4j;
|
||||||
|
|
||||||
|
/**
 * One mini-batch of rows, stored as a sparse matrix in CSR format, that can be
 * fed into a DMatrix.
 *
 * Most users never need to touch this class directly; it exists to support
 * the advanced path of building a DMatrix from an Iterator of DataBatch.
 */
public class DataBatch {
  /** Start position of every row inside featureIndex/featureValue, plus one trailing end position. */
  long[] rowOffset;
  /** Per-row weight; may be null. */
  float[] weight;
  /** Per-row label; may be null. */
  float[] label;
  /** Feature (column) index of each stored entry. */
  int[] featureIndex;
  /** Value of each stored (non-missing) entry. */
  float[] featureValue;

  /**
   * Get number of rows in the data batch.
   * @return Number of rows in the data batch.
   */
  public int numRows() {
    // rowOffset carries one more entry than there are rows.
    final int offsetCount = rowOffset.length;
    return offsetCount - 1;
  }

  /**
   * Shallow copy a DataBatch: the copy shares every array with this batch.
   * @return a copy of the batch
   */
  public DataBatch shallowCopy() {
    final DataBatch duplicate = new DataBatch();
    duplicate.featureValue = featureValue;
    duplicate.featureIndex = featureIndex;
    duplicate.label = label;
    duplicate.weight = weight;
    duplicate.rowOffset = rowOffset;
    return duplicate;
  }
}
|
||||||
@ -15,6 +15,7 @@
|
|||||||
*/
|
*/
|
||||||
package ml.dmlc.xgboost4j;
|
package ml.dmlc.xgboost4j;
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* xgboost JNI functions
|
* xgboost JNI functions
|
||||||
* change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster
|
* change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster
|
||||||
@ -26,6 +27,8 @@ class XgboostJNI {
|
|||||||
|
|
||||||
public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);
|
public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);
|
||||||
|
|
||||||
|
/**
 * Native entry point that builds a DMatrix from an iterator of DataBatch.
 * The native side stores the handle of the new matrix in out[0] and returns
 * the status code of the underlying C API call.
 * cache_info may be null.
 */
public final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter, String cache_info, long[] out);
|
||||||
|
|
||||||
public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data,
|
public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data,
|
||||||
long[] out);
|
long[] out);
|
||||||
|
|
||||||
|
|||||||
@ -20,13 +20,124 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
//helper functions
|
// helper functions
|
||||||
//set handle
|
// set handle
|
||||||
/*!
 * \brief Store a native pointer in slot 0 of a Java long[] so the Java side
 *        can keep it as an opaque handle.
 * \param jenv The JNI environment of the calling thread.
 * \param jhandle The Java long[] (length >= 1) that receives the handle.
 * \param handle The native pointer to publish.
 */
void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle) {
  // Use jlong rather than long: SetLongArrayRegion expects a jlong buffer, and
  // long is only 32 bits wide on LLP64 platforms, which would truncate the pointer.
  jlong out = reinterpret_cast<jlong>(handle);
  jenv->SetLongArrayRegion(jhandle, 0, 1, &out);
}
|
||||||
|
|
||||||
|
// global JVM
|
||||||
|
static JavaVM* global_jvm = nullptr;
|
||||||
|
|
||||||
|
// overrides JNI on load
|
||||||
|
jint JNI_OnLoad(JavaVM *vm, void *reserved) {
|
||||||
|
global_jvm = vm;
|
||||||
|
return JNI_VERSION_1_6;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*!
 * \brief Data reading callback that adapts a java.util.Iterator of DataBatch
 *        to the XGBCallbackDataIterNext interface.
 *
 *  When the iterator has another element, the batch arrays are validated,
 *  handed to set_function and released again before the function returns.
 *
 * \param data_handle The Java iterator object (a jobject).
 * \param set_function The callback that receives the next batch.
 * \param set_function_handle The handle forwarded to set_function.
 * \return 1 when a batch was delivered, 0 when the iterator is exhausted.
 *         Validation failures are reported through CHECK / LOG(FATAL).
 */
XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
    DataIterHandle data_handle,
    XGBCallbackSetData* set_function,
    DataHolderHandle set_function_handle) {
  jobject jiter = static_cast<jobject>(data_handle);
  JNIEnv* jenv;
  int jni_status = global_jvm->GetEnv(
      reinterpret_cast<void **>(&jenv), JNI_VERSION_1_6);
  if (jni_status == JNI_EDETACHED) {
    // called from a thread the JVM does not know yet: attach it for this call.
    jint attach_status = global_jvm->AttachCurrentThread(
        reinterpret_cast<void **>(&jenv), nullptr);
    CHECK(attach_status == JNI_OK)
        << "failed to attach the calling thread to the JVM";
  } else {
    CHECK(jni_status == JNI_OK);
  }
  try {
    jclass iterClass = jenv->FindClass("java/util/Iterator");
    jmethodID hasNext = jenv->GetMethodID(iterClass,
                                          "hasNext", "()Z");
    jmethodID next = jenv->GetMethodID(iterClass,
                                       "next", "()Ljava/lang/Object;");
    int ret_value;
    if (jenv->CallBooleanMethod(jiter, hasNext)) {
      jobject batch = jenv->CallObjectMethod(jiter, next);
      CHECK(batch != nullptr) << "the iterator returned a null batch";
      jclass batchClass = jenv->GetObjectClass(batch);
      jlongArray joffset = (jlongArray)jenv->GetObjectField(
          batch, jenv->GetFieldID(batchClass, "rowOffset", "[J"));
      jfloatArray jlabel = (jfloatArray)jenv->GetObjectField(
          batch, jenv->GetFieldID(batchClass, "label", "[F"));
      jfloatArray jweight = (jfloatArray)jenv->GetObjectField(
          batch, jenv->GetFieldID(batchClass, "weight", "[F"));
      jintArray jindex = (jintArray)jenv->GetObjectField(
          batch, jenv->GetFieldID(batchClass, "featureIndex", "[I"));
      jfloatArray jvalue = (jfloatArray)jenv->GetObjectField(
          batch, jenv->GetFieldID(batchClass, "featureValue", "[F"));

      // Validate everything BEFORE pinning any array elements, so that a
      // failing CHECK cannot leave pinned elements unreleased.
      CHECK(joffset != nullptr) << "batch.rowOffset must not be null";
      CHECK(jindex != nullptr) << "batch.featureIndex must not be null";
      CHECK(jvalue != nullptr) << "batch.featureValue must not be null";
      const jsize num_offset = jenv->GetArrayLength(joffset);
      CHECK(num_offset >= 1)
          << "batch.rowOffset must contain at least one entry";

      XGBoostBatchCSR cbatch;
      cbatch.size = num_offset - 1;
      if (jlabel != nullptr) {
        CHECK_EQ(jenv->GetArrayLength(jlabel), static_cast<long>(cbatch.size))
            << "batch.label.length must equal batch.numRows()";
      }
      if (jweight != nullptr) {
        CHECK_EQ(jenv->GetArrayLength(jweight), static_cast<long>(cbatch.size))
            << "batch.weight.length must equal batch.numRows()";
      }
      // read the last offset through a region copy; nothing is pinned yet.
      jlong last_offset = 0;
      jenv->GetLongArrayRegion(joffset, num_offset - 1, 1, &last_offset);
      long max_elem = static_cast<long>(last_offset);
      CHECK_EQ(jenv->GetArrayLength(jindex), max_elem)
          << "batch.index.length must equal batch.offset.back()";
      CHECK_EQ(jenv->GetArrayLength(jvalue), max_elem)
          << "batch.value.length must equal batch.offset.back()";

      // pin the array contents for the duration of the set_function call.
      cbatch.offset = jenv->GetLongArrayElements(joffset, 0);
      if (jlabel != nullptr) {
        cbatch.label = jenv->GetFloatArrayElements(jlabel, 0);
      } else {
        cbatch.label = nullptr;
      }
      if (jweight != nullptr) {
        cbatch.weight = jenv->GetFloatArrayElements(jweight, 0);
      } else {
        cbatch.weight = nullptr;
      }
      cbatch.index = jenv->GetIntArrayElements(jindex, 0);
      cbatch.value = jenv->GetFloatArrayElements(jvalue, 0);

      // cbatch is ready; remember the status and check it only after the
      // elements have been released.
      const int set_status = (*set_function)(set_function_handle, cbatch);

      // release the elements.
      jenv->ReleaseLongArrayElements(joffset, cbatch.offset, 0);
      jenv->DeleteLocalRef(joffset);
      if (jlabel != nullptr) {
        jenv->ReleaseFloatArrayElements(jlabel, cbatch.label, 0);
        jenv->DeleteLocalRef(jlabel);
      }
      if (jweight != nullptr) {
        jenv->ReleaseFloatArrayElements(jweight, cbatch.weight, 0);
        jenv->DeleteLocalRef(jweight);
      }
      jenv->ReleaseIntArrayElements(jindex, cbatch.index, 0);
      jenv->DeleteLocalRef(jindex);
      jenv->ReleaseFloatArrayElements(jvalue, cbatch.value, 0);
      jenv->DeleteLocalRef(jvalue);
      jenv->DeleteLocalRef(batch);
      jenv->DeleteLocalRef(batchClass);

      CHECK_EQ(set_status, 0)
          << XGBGetLastError();
      ret_value = 1;
    } else {
      ret_value = 0;
    }
    jenv->DeleteLocalRef(iterClass);
    // only detach if it is a async call.
    if (jni_status == JNI_EDETACHED) {
      global_jvm->DetachCurrentThread();
    }
    return ret_value;
  } catch (const dmlc::Error& e) {
    // only detach if it is a async call.
    if (jni_status == JNI_EDETACHED) {
      global_jvm->DetachCurrentThread();
    }
    // NOTE(review): LOG(FATAL) is expected to throw or abort depending on the
    // dmlc logging configuration, so the return below is normally not reached.
    LOG(FATAL) << e.what();
    return -1;
  }
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||||
|
* Method: XGBGetLastError
|
||||||
|
* Signature: ()Ljava/lang/String;
|
||||||
|
*/
|
||||||
JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError
|
JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError
|
||||||
(JNIEnv *jenv, jclass jcls) {
|
(JNIEnv *jenv, jclass jcls) {
|
||||||
jstring jresult = 0;
|
jstring jresult = 0;
|
||||||
@ -37,6 +148,32 @@ JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError
|
|||||||
return jresult;
|
return jresult;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||||
|
* Method: XGDMatrixCreateFromDataIter
|
||||||
|
* Signature: (Ljava/util/Iterator;Ljava/lang/String;[J)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromDataIter
|
||||||
|
(JNIEnv *jenv, jclass jcls, jobject jiter, jstring jcache_info, jlongArray jout) {
|
||||||
|
DMatrixHandle result;
|
||||||
|
const char* cache_info = nullptr;
|
||||||
|
if (jcache_info != nullptr) {
|
||||||
|
cache_info = jenv->GetStringUTFChars(jcache_info, 0);
|
||||||
|
}
|
||||||
|
int ret = XGDMatrixCreateFromDataIter(
|
||||||
|
jiter, XGBoost4jCallbackDataIterNext, cache_info, &result);
|
||||||
|
if (cache_info) {
|
||||||
|
jenv->ReleaseStringUTFChars(jcache_info, cache_info);
|
||||||
|
}
|
||||||
|
setHandle(jenv, jout, result);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||||
|
* Method: XGDMatrixCreateFromFile
|
||||||
|
* Signature: (Ljava/lang/String;I[J)I
|
||||||
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile
|
||||||
(JNIEnv *jenv, jclass jcls, jstring jfname, jint jsilent, jlongArray jout) {
|
(JNIEnv *jenv, jclass jcls, jstring jfname, jint jsilent, jlongArray jout) {
|
||||||
DMatrixHandle result;
|
DMatrixHandle result;
|
||||||
|
|||||||
@ -23,6 +23,14 @@ JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError
|
|||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile
|
||||||
(JNIEnv *, jclass, jstring, jint, jlongArray);
|
(JNIEnv *, jclass, jstring, jint, jlongArray);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||||
|
* Method: XGDMatrixCreateFromDataIter
|
||||||
|
* Signature: (Ljava/util/Iterator;Ljava/lang/String;[J)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromDataIter
|
||||||
|
(JNIEnv *, jclass, jobject, jstring, jlongArray);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||||
* Method: XGDMatrixCreateFromCSR
|
* Method: XGDMatrixCreateFromCSR
|
||||||
|
|||||||
@ -28,6 +28,41 @@ import org.junit.Test;
|
|||||||
*/
|
*/
|
||||||
public class DMatrixTest {
|
public class DMatrixTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCreateFromDataIterator() throws XGBoostError {
|
||||||
|
//create DMatrix from DataIterator
|
||||||
|
/**
|
||||||
|
* sparse matrix
|
||||||
|
* 1 0 2 3 0
|
||||||
|
* 4 0 2 3 5
|
||||||
|
* 3 1 2 5 0
|
||||||
|
*/
|
||||||
|
DataBatch batch = new DataBatch();
|
||||||
|
batch.featureIndex = new int[]{0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3};
|
||||||
|
batch.featureValue = new float[]{1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5};
|
||||||
|
batch.rowOffset = new long[]{0, 3, 7, 11};
|
||||||
|
batch.label = new float[] {0.1f, 0.2f, 0.3f};
|
||||||
|
java.util.ArrayList<Float> labelall = new java.util.ArrayList<Float>();
|
||||||
|
int nrep = 3;
|
||||||
|
java.util.List<DataBatch> blist = new java.util.LinkedList<DataBatch>();
|
||||||
|
for (int i = 0; i < nrep; ++i) {
|
||||||
|
batch.label = new float[] {0.1f+i, 0.2f+i, 0.3f+i};
|
||||||
|
blist.add(batch.shallowCopy());
|
||||||
|
for (float f : batch.label) {
|
||||||
|
labelall.add(f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
DMatrix dmat = new DMatrix(blist.iterator(), null);
|
||||||
|
// get label
|
||||||
|
float[] labels = dmat.getLabel();
|
||||||
|
// get label
|
||||||
|
TestCase.assertTrue(batch.label.length * nrep == labels.length);
|
||||||
|
|
||||||
|
for (int i = 0; i < labels.length; ++i) {
|
||||||
|
TestCase.assertTrue(labelall.get(i) == labels[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCreateFromFile() throws XGBoostError {
|
public void testCreateFromFile() throws XGBoostError {
|
||||||
//create DMatrix from file
|
//create DMatrix from file
|
||||||
|
|||||||
2
rabit
2
rabit
@ -1 +1 @@
|
|||||||
Subproject commit 1392e9f3da59bd5602ddebee944dd8fb5c6507b0
|
Subproject commit be50e7b63224b9fb7ff94ce34df9f8752ef83043
|
||||||
@ -19,7 +19,6 @@
|
|||||||
#include "../common/group_data.h"
|
#include "../common/group_data.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
|
|
||||||
// booster wrapper for backward compatible reason.
|
// booster wrapper for backward compatible reason.
|
||||||
class Booster {
|
class Booster {
|
||||||
public:
|
public:
|
||||||
@ -61,6 +60,113 @@ class Booster {
|
|||||||
std::unique_ptr<Learner> learner_;
|
std::unique_ptr<Learner> learner_;
|
||||||
std::vector<std::pair<std::string, std::string> > cfg_;
|
std::vector<std::pair<std::string, std::string> > cfg_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// declare the data callback.
|
||||||
|
XGB_EXTERN_C int XGBoostNativeDataIterSetData(
|
||||||
|
void *handle, XGBoostBatchCSR batch);
|
||||||
|
|
||||||
|
/*! \brief Native data iterator that takes callback to return data */
|
||||||
|
class NativeDataIter : public dmlc::Parser<uint32_t> {
|
||||||
|
public:
|
||||||
|
NativeDataIter(DataIterHandle data_handle,
|
||||||
|
XGBCallbackDataIterNext* next_callback)
|
||||||
|
: at_first_(true), bytes_read_(0),
|
||||||
|
data_handle_(data_handle), next_callback_(next_callback) {
|
||||||
|
}
|
||||||
|
|
||||||
|
// override functions
|
||||||
|
void BeforeFirst() override {
|
||||||
|
CHECK(at_first_) << "cannot reset NativeDataIter";
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Next() override {
|
||||||
|
if ((*next_callback_)(
|
||||||
|
data_handle_,
|
||||||
|
XGBoostNativeDataIterSetData,
|
||||||
|
this) != 0) {
|
||||||
|
at_first_ = false;
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const dmlc::RowBlock<uint32_t>& Value() const override {
|
||||||
|
return block_;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t BytesRead() const override {
|
||||||
|
return bytes_read_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// callback to set the data
|
||||||
|
void SetData(const XGBoostBatchCSR& batch) {
|
||||||
|
offset_.clear();
|
||||||
|
label_.clear();
|
||||||
|
weight_.clear();
|
||||||
|
index_.clear();
|
||||||
|
value_.clear();
|
||||||
|
offset_.insert(offset_.end(), batch.offset, batch.offset + batch.size + 1);
|
||||||
|
if (batch.label != nullptr) {
|
||||||
|
label_.insert(label_.end(), batch.label, batch.label + batch.size);
|
||||||
|
}
|
||||||
|
if (batch.weight != nullptr) {
|
||||||
|
weight_.insert(weight_.end(), batch.weight, batch.weight + batch.size);
|
||||||
|
}
|
||||||
|
if (batch.index != nullptr) {
|
||||||
|
index_.insert(index_.end(), batch.index + offset_[0], batch.index + offset_.back());
|
||||||
|
}
|
||||||
|
if (batch.value != nullptr) {
|
||||||
|
value_.insert(value_.end(), batch.value + offset_[0], batch.value + offset_.back());
|
||||||
|
}
|
||||||
|
if (offset_[0] != 0) {
|
||||||
|
size_t base = offset_[0];
|
||||||
|
for (size_t& item : offset_) {
|
||||||
|
item -= base;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
block_.size = batch.size;
|
||||||
|
block_.offset = dmlc::BeginPtr(offset_);
|
||||||
|
block_.label = dmlc::BeginPtr(label_);
|
||||||
|
block_.weight = dmlc::BeginPtr(weight_);
|
||||||
|
block_.index = dmlc::BeginPtr(index_);
|
||||||
|
block_.value = dmlc::BeginPtr(value_);
|
||||||
|
bytes_read_ += offset_.size() * sizeof(size_t) +
|
||||||
|
label_.size() * sizeof(dmlc::real_t) +
|
||||||
|
weight_.size() * sizeof(dmlc::real_t) +
|
||||||
|
index_.size() * sizeof(uint32_t) +
|
||||||
|
value_.size() * sizeof(dmlc::real_t);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// at the beinning.
|
||||||
|
bool at_first_;
|
||||||
|
// bytes that is read.
|
||||||
|
size_t bytes_read_;
|
||||||
|
// handle to the iterator,
|
||||||
|
DataIterHandle data_handle_;
|
||||||
|
// call back to get the data.
|
||||||
|
XGBCallbackDataIterNext* next_callback_;
|
||||||
|
// internal offset
|
||||||
|
std::vector<size_t> offset_;
|
||||||
|
// internal label data
|
||||||
|
std::vector<dmlc::real_t> label_;
|
||||||
|
// internal weight data
|
||||||
|
std::vector<dmlc::real_t> weight_;
|
||||||
|
// internal index.
|
||||||
|
std::vector<uint32_t> index_;
|
||||||
|
// internal value.
|
||||||
|
std::vector<dmlc::real_t> value_;
|
||||||
|
// internal Rowblock
|
||||||
|
dmlc::RowBlock<uint32_t> block_;
|
||||||
|
};
|
||||||
|
|
||||||
|
int XGBoostNativeDataIterSetData(
|
||||||
|
void *handle, XGBoostBatchCSR batch) {
|
||||||
|
API_BEGIN();
|
||||||
|
static_cast<xgboost::NativeDataIter*>(handle)->SetData(batch);
|
||||||
|
API_END();
|
||||||
|
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|
||||||
using namespace xgboost; // NOLINT(*);
|
using namespace xgboost; // NOLINT(*);
|
||||||
@ -95,6 +201,22 @@ int XGDMatrixCreateFromFile(const char *fname,
|
|||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int XGDMatrixCreateFromDataIter(
|
||||||
|
void* data_handle,
|
||||||
|
XGBCallbackDataIterNext* callback,
|
||||||
|
const char *cache_info,
|
||||||
|
DMatrixHandle *out) {
|
||||||
|
API_BEGIN();
|
||||||
|
|
||||||
|
std::string scache;
|
||||||
|
if (cache_info != nullptr) {
|
||||||
|
scache = cache_info;
|
||||||
|
}
|
||||||
|
NativeDataIter parser(data_handle, callback);
|
||||||
|
*out = DMatrix::Create(&parser, scache);
|
||||||
|
API_END();
|
||||||
|
}
|
||||||
|
|
||||||
int XGDMatrixCreateFromCSR(const bst_ulong* indptr,
|
int XGDMatrixCreateFromCSR(const bst_ulong* indptr,
|
||||||
const unsigned *indices,
|
const unsigned *indices,
|
||||||
const float* data,
|
const float* data,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user