diff --git a/dmlc-core b/dmlc-core
index 71360023d..3f6ff43d3 160000
--- a/dmlc-core
+++ b/dmlc-core
@@ -1 +1 @@
-Subproject commit 71360023dba458bdc9f1bc6f4309c1a107cb83a0
+Subproject commit 3f6ff43d3976d5b6d5001608b0e3e526ecde098f
diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h
index e950a2765..68f35ace9 100644
--- a/include/xgboost/c_api.h
+++ b/include/xgboost/c_api.h
@@ -12,6 +12,9 @@
 #endif

 // XGBoost C API will include APIs in Rabit C API
+XGB_EXTERN_C {
+#include
+}
 #include

 #if defined(_MSC_VER) || defined(_WIN32)
@@ -26,6 +29,51 @@ typedef unsigned long bst_ulong; // NOLINT(*)
 typedef void *DMatrixHandle;
 /*! \brief handle to Booster */
 typedef void *BoosterHandle;
+/*! \brief handle to a data iterator */
+typedef void *DataIterHandle;
+/*! \brief handle to an internal data holder. */
+typedef void *DataHolderHandle;
+
+/*! \brief Mini batch used in XGBoost Data Iteration, in CSR layout */
+typedef struct {
+  /*! \brief number of rows in the minibatch */
+  size_t size;
+  /*! \brief row offsets into index and value, size + 1 entries */
+  long* offset; // NOLINT(*)
+  /*! \brief labels of each instance, can be NULL */
+  float* label;
+  /*! \brief weight of each instance, can be NULL */
+  float* weight;
+  /*! \brief feature index of each stored entry */
+  int* index;
+  /*! \brief feature value of each stored entry */
+  float* value;
+} XGBoostBatchCSR;
+
+
+/*!
+ * \brief Callback that hands one batch of data over to a data holder.
+ * \param handle The handle to the data holder that receives the batch.
+ * \param batch The data content to be set.
+ */
+XGB_EXTERN_C typedef int XGBCallbackSetData(
+    DataHolderHandle handle, XGBoostBatchCSR batch);
+
+/*!
+ * \brief The data reading callback function.
+ *  The iterator will be able to give subset of batch in the data.
+ *
+ *  If there is data, the function will call set_function to set the data.
+ *
+ * \param data_handle The handle to the data iterator.
+ * \param set_function The callback used to hand the batch over.
+ * \param set_function_handle The handle to be passed to set_function.
+ * \return 0 if the end is reached and no batch is set, nonzero otherwise.
+ */
+XGB_EXTERN_C typedef int XGBCallbackDataIterNext(
+    DataIterHandle data_handle,
+    XGBCallbackSetData* set_function,
+    DataHolderHandle set_function_handle);

 /*!
  * \brief get string message of the last error
@@ -50,6 +98,20 @@ XGB_DLL int XGDMatrixCreateFromFile(const char *fname,
                                     int silent,
                                     DMatrixHandle *out);

+/*!
+ * \brief Create a DMatrix from a data iterator.
+ * \param data_handle The handle to the data iterator, passed back to callback.
+ * \param callback The callback that is invoked to get each batch of data.
+ * \param cache_info Additional information about cache file, can be null.
+ * \param out The created DMatrix
+ * \return 0 when success, -1 when failure happens.
+ */
+XGB_DLL int XGDMatrixCreateFromDataIter(
+    DataIterHandle data_handle,
+    XGBCallbackDataIterNext* callback,
+    const char* cache_info,
+    DMatrixHandle *out);
+
 /*!
  * \brief create a matrix content from csr format
  * \param indptr pointer to row headers
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/DMatrix.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/DMatrix.java
index f2100b3a3..e2b3ecc47 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/DMatrix.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/DMatrix.java
@@ -16,6 +16,7 @@
 package ml.dmlc.xgboost4j;

 import java.io.IOException;
+import java.util.Iterator;

 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
@@ -47,6 +48,33 @@ public class DMatrix {
     CSC;
   }

+  /**
+   * Create DMatrix from an iterator of mini batches.
+   *
+   * @param iter Iterator over the mini batches that provide the data, must not be null.
+   * @param cache_info Cache path information, used for external memory setting, can be null.
+   * @throws XGBoostError
+   */
+  public DMatrix(Iterator iter, String cache_info) throws XGBoostError {
+    if (iter == null) {
+      throw new NullPointerException("iter: null");
+    }
+    try {  // NOTE(review): this block only logs the iterator's next() method; looks like debug residue
+      logger.info(iter.getClass().getMethod("next").toString());
+    } catch(NoSuchMethodException e) {
+      logger.info(e.toString());
+    }
+    long[] out = new long[1];
+    JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromDataIter(iter, cache_info, out));
+    handle = out[0];
+  }
+
+  /**
+   * Create DMatrix by loading a libsvm file from dataPath.
+   *
+   * @param dataPath The path to the data, must not be null.
+   * @throws XGBoostError
+   */
   public DMatrix(String dataPath) throws XGBoostError {
     if (dataPath == null) {
       throw new NullPointerException("dataPath: null");
@@ -56,6 +84,14 @@ public class DMatrix {
     handle = out[0];
   }

+  /**
+   * Create DMatrix from a sparse matrix in CSR/CSC format.
+   * @param headers The row index of the matrix.
+   * @param indices The indices of the present entries.
+   * @param data The data content.
+   * @param st Type of sparsity (CSR or CSC).
+   * @throws XGBoostError
+   */
   public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError {
     long[] out = new long[1];
     if (st == SparseType.CSR) {
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/DataBatch.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/DataBatch.java
new file mode 100644
index 000000000..2e48b02f5
--- /dev/null
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/DataBatch.java
@@ -0,0 +1,43 @@
+package ml.dmlc.xgboost4j;
+
+/**
+ * A mini-batch of data that can be converted to DMatrix.
+ * The data is in sparse matrix CSR format.
+ *
+ * Usually this object is not needed.
+ *
+ * This class is used to support advanced creation of DMatrix from an Iterator of DataBatch.
+ */
+public class DataBatch {
+  /** The offset of each row in the sparse matrix, holds numRows() + 1 entries */
+  long[] rowOffset = null;
+  /** weight of each data point, can be null */
+  float[] weight = null;
+  /** label of each data point, can be null */
+  float[] label = null;
+  /** index of each feature(column) in the sparse matrix */
+  int[] featureIndex = null;
+  /** value of each non-missing entry in the sparse matrix */
+  float[] featureValue = null;
+  /**
+   * Get number of rows in the data batch, derived from the length of rowOffset.
+   * @return Number of rows in the data batch.
+   */
+  public int numRows() {
+    return rowOffset.length - 1;
+  }
+
+  /**
+   * Shallow copy a DataBatch: the copy shares the arrays of this batch.
+   * @return a copy of the batch
+   */
+  public DataBatch shallowCopy() {
+    DataBatch b = new DataBatch();
+    b.rowOffset = this.rowOffset;
+    b.weight = this.weight;
+    b.label = this.label;
+    b.featureIndex = this.featureIndex;
+    b.featureValue = this.featureValue;
+    return b;
+  }
+}
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XgboostJNI.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XgboostJNI.java
index 8eded82a7..160396df0 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XgboostJNI.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XgboostJNI.java
@@ -15,6 +15,7 @@
  */
 package ml.dmlc.xgboost4j;

+
 /**
  * xgboost JNI functions
  * change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster
@@ -26,6 +27,8 @@ class XgboostJNI {

   public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);

+  public final static native int XGDMatrixCreateFromDataIter(java.util.Iterator iter, String cache_info, long[] out);
+
   public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data,
                                                         long[] out);

diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp
b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp
index da3f5a92d..1c35c8311 100644
--- a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp
+++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp
@@ -20,13 +20,158 @@
 #include
 #include

-//helper functions
-//set handle
+// helper functions
+// set handle
 void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle) {
   long out = (long) handle;
   jenv->SetLongArrayRegion(jhandle, 0, 1, &out);
 }

+// global JVM, recorded in JNI_OnLoad so that callbacks can obtain a JNIEnv
+static JavaVM* global_jvm = nullptr;
+
+// overrides JNI on load
+jint JNI_OnLoad(JavaVM *vm, void *reserved) {
+  global_jvm = vm;
+  return JNI_VERSION_1_6;
+}
+
+/*!
+ * \brief Data iterator callback backed by a java.util.Iterator.
+ *
+ *  Takes the next element of the iterator, reads its rowOffset, label, weight,
+ *  featureIndex and featureValue array fields and passes them to set_function
+ *  as an XGBoostBatchCSR.
+ *
+ * \param data_handle jobject of the java.util.Iterator.
+ * \param set_function The callback that receives the batch.
+ * \param set_function_handle The handle forwarded to set_function.
+ * \return 1 if a batch was passed to set_function, 0 if the iterator has no
+ *  more elements. Failures are raised through CHECK / LOG(FATAL).
+ */
+XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
+    DataIterHandle data_handle,
+    XGBCallbackSetData* set_function,
+    DataHolderHandle set_function_handle) {
+  jobject jiter = static_cast<jobject>(data_handle);
+  JNIEnv* jenv = nullptr;
+  int jni_status = global_jvm->GetEnv(
+      reinterpret_cast<void **>(&jenv), JNI_VERSION_1_6);
+  if (jni_status == JNI_EDETACHED) {
+    // the calling thread is not attached to the JVM yet.
+    jint attach_status = global_jvm->AttachCurrentThread(
+        reinterpret_cast<void **>(&jenv), nullptr);
+    CHECK(attach_status == JNI_OK)
+        << "cannot attach the current thread to the JVM";
+  } else {
+    CHECK(jni_status == JNI_OK);
+  }
+  try {
+    jclass iterClass = jenv->FindClass("java/util/Iterator");
+    jmethodID hasNext = jenv->GetMethodID(iterClass,
+                                          "hasNext", "()Z");
+    jmethodID next = jenv->GetMethodID(iterClass,
+                                       "next", "()Ljava/lang/Object;");
+    int ret_value = 0;
+    if (jenv->CallBooleanMethod(jiter, hasNext)) {
+      jobject batch = jenv->CallObjectMethod(jiter, next);
+      CHECK(batch != nullptr) << "the iterator returned a null batch";
+      jclass batchClass = jenv->GetObjectClass(batch);
+      jlongArray joffset = static_cast<jlongArray>(jenv->GetObjectField(
+          batch, jenv->GetFieldID(batchClass, "rowOffset", "[J")));
+      jfloatArray jlabel = static_cast<jfloatArray>(jenv->GetObjectField(
+          batch, jenv->GetFieldID(batchClass, "label", "[F")));
+      jfloatArray jweight = static_cast<jfloatArray>(jenv->GetObjectField(
+          batch, jenv->GetFieldID(batchClass, "weight", "[F")));
+      jintArray jindex = static_cast<jintArray>(jenv->GetObjectField(
+          batch, jenv->GetFieldID(batchClass, "featureIndex", "[I")));
+      jfloatArray jvalue = static_cast<jfloatArray>(jenv->GetObjectField(
+          batch, jenv->GetFieldID(batchClass, "featureValue", "[F")));
+      CHECK(joffset != nullptr) << "batch.rowOffset must not be null";
+      CHECK(jindex != nullptr) << "batch.featureIndex must not be null";
+      CHECK(jvalue != nullptr) << "batch.featureValue must not be null";
+      // validate every length before any array is pinned, so that a failed
+      // CHECK cannot leave array elements unreleased.
+      const jsize num_offset = jenv->GetArrayLength(joffset);
+      CHECK_GE(num_offset, 1) << "batch.rowOffset must not be empty";
+      const jsize num_row = num_offset - 1;
+      if (jlabel != nullptr) {
+        CHECK_EQ(jenv->GetArrayLength(jlabel), num_row)
+            << "batch.label.length must equal batch.numRows()";
+      }
+      if (jweight != nullptr) {
+        CHECK_EQ(jenv->GetArrayLength(jweight), num_row)
+            << "batch.weight.length must equal batch.numRows()";
+      }
+      jlong max_elem = 0;
+      jenv->GetLongArrayRegion(joffset, num_row, 1, &max_elem);
+      CHECK_EQ(jenv->GetArrayLength(jindex), max_elem)
+          << "batch.index.length must equal batch.offset.back()";
+      CHECK_EQ(jenv->GetArrayLength(jvalue), max_elem)
+          << "batch.value.length must equal batch.offset.back()";
+      XGBoostBatchCSR cbatch;
+      cbatch.size = static_cast<size_t>(num_row);
+      // NOTE(review): storing jlong*/jint* into long*/int* fields assumes the
+      // JNI types equal the C types on this platform - confirm for Windows.
+      cbatch.offset = jenv->GetLongArrayElements(joffset, nullptr);
+      cbatch.label = nullptr;
+      if (jlabel != nullptr) {
+        cbatch.label = jenv->GetFloatArrayElements(jlabel, nullptr);
+      }
+      cbatch.weight = nullptr;
+      if (jweight != nullptr) {
+        cbatch.weight = jenv->GetFloatArrayElements(jweight, nullptr);
+      }
+      cbatch.index = jenv->GetIntArrayElements(jindex, nullptr);
+      cbatch.value = jenv->GetFloatArrayElements(jvalue, nullptr);
+      // cbatch is ready
+      const int set_status = (*set_function)(set_function_handle, cbatch);
+      // release the elements. This function only reads them, so JNI_ABORT
+      // is used to skip the copy-back that mode 0 would perform.
+      jenv->ReleaseLongArrayElements(joffset, cbatch.offset, JNI_ABORT);
+      jenv->DeleteLocalRef(joffset);
+      if (jlabel != nullptr) {
+        jenv->ReleaseFloatArrayElements(jlabel, cbatch.label, JNI_ABORT);
+        jenv->DeleteLocalRef(jlabel);
+      }
+      if (jweight != nullptr) {
+        jenv->ReleaseFloatArrayElements(jweight, cbatch.weight, JNI_ABORT);
+        jenv->DeleteLocalRef(jweight);
+      }
+      jenv->ReleaseIntArrayElements(jindex, cbatch.index, JNI_ABORT);
+      jenv->DeleteLocalRef(jindex);
+      jenv->ReleaseFloatArrayElements(jvalue, cbatch.value, JNI_ABORT);
+      jenv->DeleteLocalRef(jvalue);
+      jenv->DeleteLocalRef(batch);
+      jenv->DeleteLocalRef(batchClass);
+      // report a failed set_function only after everything was released.
+      CHECK_EQ(set_status, 0)
+          << XGBGetLastError();
+      ret_value = 1;
+    }
+    jenv->DeleteLocalRef(iterClass);
+    // only detach if it is a async call.
+    if (jni_status == JNI_EDETACHED) {
+      global_jvm->DetachCurrentThread();
+    }
+    return ret_value;
+  } catch (const dmlc::Error& e) {
+    // only detach if it is a async call.
+    if (jni_status == JNI_EDETACHED) {
+      global_jvm->DetachCurrentThread();
+    }
+    // NOTE(review): presumably LOG(FATAL) raises again so the error reaches
+    // the caller; the return below only runs if it does not - confirm.
+    LOG(FATAL) << e.what();
+    return -1;
+  }
+}
+
+/*
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
+ * Method: XGBGetLastError
+ * Signature: ()Ljava/lang/String;
+ */
 JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError
   (JNIEnv *jenv, jclass jcls) {
   jstring jresult = 0;
@@ -37,6 +182,34 @@ JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError
   return jresult;
 }

+/*
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
+ * Method: XGDMatrixCreateFromDataIter
+ * Signature: (Ljava/util/Iterator;Ljava/lang/String;[J)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromDataIter
+  (JNIEnv *jenv, jclass jcls, jobject jiter, jstring jcache_info, jlongArray jout) {
+  // start from a null handle so that a failed call never stores an
+  // indeterminate pointer into jout.
+  DMatrixHandle result = nullptr;
+  const char* cache_info = nullptr;
+  if (jcache_info != nullptr) {
+    cache_info = jenv->GetStringUTFChars(jcache_info, 0);
+  }
+  int ret = XGDMatrixCreateFromDataIter(
+      jiter, XGBoost4jCallbackDataIterNext, cache_info, &result);
+  if (cache_info) {
+    jenv->ReleaseStringUTFChars(jcache_info, cache_info);
+  }
+  setHandle(jenv, jout, result);
+  return ret;
+}
+
+/*
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
+ * Method: XGDMatrixCreateFromFile
+ * Signature: (Ljava/lang/String;I[J)I
+ */
 JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile
   (JNIEnv *jenv, jclass jcls, jstring jfname, jint jsilent, jlongArray jout) {
   DMatrixHandle result;
diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.h b/jvm-packages/xgboost4j/src/native/xgboost4j.h
index 6d811ad88..0a3eeae3a 100644
--- a/jvm-packages/xgboost4j/src/native/xgboost4j.h
+++ b/jvm-packages/xgboost4j/src/native/xgboost4j.h
@@ -23,6 +23,14 @@ JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError
 JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile
   (JNIEnv *, jclass, jstring, jint, jlongArray);

+/*
+ * Class: ml_dmlc_xgboost4j_XgboostJNI
+ * Method: XGDMatrixCreateFromDataIter
+ * Signature: (Ljava/util/Iterator;Ljava/lang/String;[J)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromDataIter
+  (JNIEnv *, jclass, jobject, jstring, jlongArray);
+
 /*
  * Class: ml_dmlc_xgboost4j_XgboostJNI
  * Method: XGDMatrixCreateFromCSR
diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/DMatrixTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/DMatrixTest.java
index 8868d1600..6c8206940 100644
--- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/DMatrixTest.java
+++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/DMatrixTest.java
@@ -28,6 +28,41 @@ import org.junit.Test;
  */
 public class DMatrixTest {

+  @Test
+  public void testCreateFromDataIterator() throws XGBoostError {
+    // create DMatrix from an iterator of DataBatch
+    /**
+     * sparse matrix held by every batch (3 rows, CSR)
+     * 1 0 2 3 0
+     * 4 0 2 3 5
+     * 3 1 2 5 0
+     */
+    DataBatch batch = new DataBatch();
+    batch.featureIndex = new int[]{0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3};
+    batch.featureValue = new float[]{1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5};
+    batch.rowOffset = new long[]{0, 3, 7, 11};
+    batch.label = new float[] {0.1f, 0.2f, 0.3f};
+    java.util.ArrayList labelall = new java.util.ArrayList();
+    int nrep = 3;
+    java.util.List blist = new java.util.LinkedList();
+    for (int i = 0; i < nrep; ++i) {
+      batch.label = new float[] {0.1f+i, 0.2f+i, 0.3f+i};
+      blist.add(batch.shallowCopy());
+      for (float f : batch.label) {
+        labelall.add(f);
+      }
+    }
+    DMatrix dmat = new DMatrix(blist.iterator(), null);
+    // get label
+    float[] labels = dmat.getLabel();
+    // each of the nrep batches contributes batch.label.length labels
+    TestCase.assertTrue(batch.label.length * nrep == labels.length);
+
+    for (int i = 0; i < labels.length; ++i) {
+      TestCase.assertTrue(labelall.get(i) == labels[i]);
+    }
+  }
+
   @Test
   public void
testCreateFromFile() throws XGBoostError {
     //create DMatrix from file
diff --git a/rabit b/rabit
index 1392e9f3d..be50e7b63 160000
--- a/rabit
+++ b/rabit
@@ -1 +1 @@
-Subproject commit 1392e9f3da59bd5602ddebee944dd8fb5c6507b0
+Subproject commit be50e7b63224b9fb7ff94ce34df9f8752ef83043
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index b543f9b6e..d0ea7815b 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -19,7 +19,6 @@
 #include "../common/group_data.h"

 namespace xgboost {
-
 // booster wrapper for backward compatible reason.
 class Booster {
  public:
@@ -61,6 +60,120 @@ class Booster {
   std::unique_ptr learner_;
   std::vector > cfg_;
 };
+
+// declare the data callback.
+XGB_EXTERN_C int XGBoostNativeDataIterSetData(
+    void *handle, XGBoostBatchCSR batch);
+
+/*! \brief Native data iterator that takes callback to return data */
+class NativeDataIter : public dmlc::Parser<uint32_t> {
+ public:
+  NativeDataIter(DataIterHandle data_handle,
+                 XGBCallbackDataIterNext* next_callback)
+      : at_first_(true), bytes_read_(0),
+        data_handle_(data_handle), next_callback_(next_callback) {
+  }
+
+  // override functions
+  // the iterator cannot be reset, so this is only legal before the first
+  // successful Next().
+  void BeforeFirst() override {
+    CHECK(at_first_) << "cannot reset NativeDataIter";
+  }
+
+  /*!
+   * \brief Ask the callback for the next batch; the callback delivers it
+   *  through XGBoostNativeDataIterSetData.
+   * \return true if a batch is available, false at the end of the data.
+   */
+  bool Next() override {
+    const int status = (*next_callback_)(
+        data_handle_, XGBoostNativeDataIterSetData, this);
+    // a negative status reports a failure; it must not be taken for
+    // "batch available", which would expose a stale block.
+    CHECK_GE(status, 0) << "data iterator callback failed";
+    if (status == 0) return false;
+    at_first_ = false;
+    return true;
+  }
+
+  const dmlc::RowBlock<uint32_t>& Value() const override {
+    return block_;
+  }
+
+  size_t BytesRead() const override {
+    return bytes_read_;
+  }
+
+  // callback to set the data: copies the batch into the internal buffers,
+  // rebases the offsets to start at 0 and points block_ at the copies.
+  void SetData(const XGBoostBatchCSR& batch) {
+    offset_.clear();
+    label_.clear();
+    weight_.clear();
+    index_.clear();
+    value_.clear();
+    offset_.insert(offset_.end(), batch.offset, batch.offset + batch.size + 1);
+    if (batch.label != nullptr) {
+      label_.insert(label_.end(), batch.label, batch.label + batch.size);
+    }
+    if (batch.weight != nullptr) {
+      weight_.insert(weight_.end(), batch.weight, batch.weight + batch.size);
+    }
+    if (batch.index != nullptr) {
+      index_.insert(index_.end(), batch.index + offset_[0], batch.index + offset_.back());
+    }
+    if (batch.value != nullptr) {
+      value_.insert(value_.end(), batch.value + offset_[0], batch.value + offset_.back());
+    }
+    if (offset_[0] != 0) {
+      size_t base = offset_[0];
+      for (size_t& item : offset_) {
+        item -= base;
+      }
+    }
+    block_.size = batch.size;
+    block_.offset = dmlc::BeginPtr(offset_);
+    block_.label = dmlc::BeginPtr(label_);
+    block_.weight = dmlc::BeginPtr(weight_);
+    block_.index = dmlc::BeginPtr(index_);
+    block_.value = dmlc::BeginPtr(value_);
+    bytes_read_ += offset_.size() * sizeof(size_t) +
+        label_.size() * sizeof(dmlc::real_t) +
+        weight_.size() * sizeof(dmlc::real_t) +
+        index_.size() * sizeof(uint32_t) +
+        value_.size() * sizeof(dmlc::real_t);
+  }
+
+ private:
+  // true until the first batch has been returned by Next().
+  bool at_first_;
+  // number of bytes copied by SetData so far.
+  size_t bytes_read_;
+  // handle to the iterator, passed back to next_callback_.
+  DataIterHandle data_handle_;
+  // call back to get the data.
+  XGBCallbackDataIterNext* next_callback_;
+  // internal offset
+  std::vector<size_t> offset_;
+  // internal label data
+  std::vector<dmlc::real_t> label_;
+  // internal weight data
+  std::vector<dmlc::real_t> weight_;
+  // internal index.
+  std::vector<uint32_t> index_;
+  // internal value.
+  std::vector<dmlc::real_t> value_;
+  // internal Rowblock, points into the vectors above.
+  dmlc::RowBlock<uint32_t> block_;
+};
+
+int XGBoostNativeDataIterSetData(
+    void *handle, XGBoostBatchCSR batch) {
+  API_BEGIN();
+  static_cast<xgboost::NativeDataIter*>(handle)->SetData(batch);
+  API_END();
+}
 } // namespace xgboost

 using namespace xgboost; // NOLINT(*);
@@ -95,6 +208,22 @@ int XGDMatrixCreateFromFile(const char *fname,
   API_END();
 }

+int XGDMatrixCreateFromDataIter(
+    void* data_handle,
+    XGBCallbackDataIterNext* callback,
+    const char *cache_info,
+    DMatrixHandle *out) {
+  API_BEGIN();
+
+  std::string scache;
+  if (cache_info != nullptr) {
+    scache = cache_info;
+  }
+  NativeDataIter parser(data_handle, callback);
+  *out = DMatrix::Create(&parser, scache);
+  API_END();
+}
+
 int XGDMatrixCreateFromCSR(const bst_ulong* indptr,
                            const unsigned *indices,
                            const float* data,