[JVM] Add Iterator loading API
This commit is contained in:
@@ -19,7 +19,6 @@
|
||||
#include "../common/group_data.h"
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
// Booster wrapper kept for backward compatibility.
|
||||
class Booster {
|
||||
public:
|
||||
@@ -61,6 +60,113 @@ class Booster {
|
||||
std::unique_ptr<Learner> learner_;
|
||||
std::vector<std::pair<std::string, std::string> > cfg_;
|
||||
};
|
||||
|
||||
// Forward-declare the set-data callback so that NativeDataIter::Next() can pass it to the user's iterator callback.
|
||||
XGB_EXTERN_C int XGBoostNativeDataIterSetData(
|
||||
void *handle, XGBoostBatchCSR batch);
|
||||
|
||||
/*! \brief Native data iterator that takes callback to return data */
|
||||
class NativeDataIter : public dmlc::Parser<uint32_t> {
|
||||
public:
|
||||
NativeDataIter(DataIterHandle data_handle,
|
||||
XGBCallbackDataIterNext* next_callback)
|
||||
: at_first_(true), bytes_read_(0),
|
||||
data_handle_(data_handle), next_callback_(next_callback) {
|
||||
}
|
||||
|
||||
// override functions
|
||||
void BeforeFirst() override {
|
||||
CHECK(at_first_) << "cannot reset NativeDataIter";
|
||||
}
|
||||
|
||||
bool Next() override {
|
||||
if ((*next_callback_)(
|
||||
data_handle_,
|
||||
XGBoostNativeDataIterSetData,
|
||||
this) != 0) {
|
||||
at_first_ = false;
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
const dmlc::RowBlock<uint32_t>& Value() const override {
|
||||
return block_;
|
||||
}
|
||||
|
||||
size_t BytesRead() const override {
|
||||
return bytes_read_;
|
||||
}
|
||||
|
||||
// callback to set the data
|
||||
void SetData(const XGBoostBatchCSR& batch) {
|
||||
offset_.clear();
|
||||
label_.clear();
|
||||
weight_.clear();
|
||||
index_.clear();
|
||||
value_.clear();
|
||||
offset_.insert(offset_.end(), batch.offset, batch.offset + batch.size + 1);
|
||||
if (batch.label != nullptr) {
|
||||
label_.insert(label_.end(), batch.label, batch.label + batch.size);
|
||||
}
|
||||
if (batch.weight != nullptr) {
|
||||
weight_.insert(weight_.end(), batch.weight, batch.weight + batch.size);
|
||||
}
|
||||
if (batch.index != nullptr) {
|
||||
index_.insert(index_.end(), batch.index + offset_[0], batch.index + offset_.back());
|
||||
}
|
||||
if (batch.value != nullptr) {
|
||||
value_.insert(value_.end(), batch.value + offset_[0], batch.value + offset_.back());
|
||||
}
|
||||
if (offset_[0] != 0) {
|
||||
size_t base = offset_[0];
|
||||
for (size_t& item : offset_) {
|
||||
item -= base;
|
||||
}
|
||||
}
|
||||
block_.size = batch.size;
|
||||
block_.offset = dmlc::BeginPtr(offset_);
|
||||
block_.label = dmlc::BeginPtr(label_);
|
||||
block_.weight = dmlc::BeginPtr(weight_);
|
||||
block_.index = dmlc::BeginPtr(index_);
|
||||
block_.value = dmlc::BeginPtr(value_);
|
||||
bytes_read_ += offset_.size() * sizeof(size_t) +
|
||||
label_.size() * sizeof(dmlc::real_t) +
|
||||
weight_.size() * sizeof(dmlc::real_t) +
|
||||
index_.size() * sizeof(uint32_t) +
|
||||
value_.size() * sizeof(dmlc::real_t);
|
||||
}
|
||||
|
||||
private:
|
||||
// at the beinning.
|
||||
bool at_first_;
|
||||
// bytes that is read.
|
||||
size_t bytes_read_;
|
||||
// handle to the iterator,
|
||||
DataIterHandle data_handle_;
|
||||
// call back to get the data.
|
||||
XGBCallbackDataIterNext* next_callback_;
|
||||
// internal offset
|
||||
std::vector<size_t> offset_;
|
||||
// internal label data
|
||||
std::vector<dmlc::real_t> label_;
|
||||
// internal weight data
|
||||
std::vector<dmlc::real_t> weight_;
|
||||
// internal index.
|
||||
std::vector<uint32_t> index_;
|
||||
// internal value.
|
||||
std::vector<dmlc::real_t> value_;
|
||||
// internal Rowblock
|
||||
dmlc::RowBlock<uint32_t> block_;
|
||||
};
|
||||
|
||||
int XGBoostNativeDataIterSetData(
|
||||
void *handle, XGBoostBatchCSR batch) {
|
||||
API_BEGIN();
|
||||
static_cast<xgboost::NativeDataIter*>(handle)->SetData(batch);
|
||||
API_END();
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
using namespace xgboost; // NOLINT(*);
|
||||
@@ -95,6 +201,22 @@ int XGDMatrixCreateFromFile(const char *fname,
|
||||
API_END();
|
||||
}
|
||||
|
||||
int XGDMatrixCreateFromDataIter(
|
||||
void* data_handle,
|
||||
XGBCallbackDataIterNext* callback,
|
||||
const char *cache_info,
|
||||
DMatrixHandle *out) {
|
||||
API_BEGIN();
|
||||
|
||||
std::string scache;
|
||||
if (cache_info != nullptr) {
|
||||
scache = cache_info;
|
||||
}
|
||||
NativeDataIter parser(data_handle, callback);
|
||||
*out = DMatrix::Create(&parser, scache);
|
||||
API_END();
|
||||
}
|
||||
|
||||
int XGDMatrixCreateFromCSR(const bst_ulong* indptr,
|
||||
const unsigned *indices,
|
||||
const float* data,
|
||||
|
||||
Reference in New Issue
Block a user