Add number of columns to native data iterator. (#5202)

* Change native data iter into an adapter.
This commit is contained in:
Jiaming Yuan
2020-02-25 23:42:01 +08:00
committed by GitHub
parent e0509b3307
commit f2b8cd2922
11 changed files with 244 additions and 156 deletions

View File

@@ -1,18 +1,26 @@
/*!
* Copyright (c) 2019 by Contributors
* Copyright (c) 2019~2020 by Contributors
* \file adapter.h
*/
#ifndef XGBOOST_DATA_ADAPTER_H_
#define XGBOOST_DATA_ADAPTER_H_
#include <dmlc/data.h>
#include <cstddef>
#include <functional>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "xgboost/logging.h"
#include "xgboost/base.h"
#include "xgboost/data.h"
#include "xgboost/span.h"
#include "xgboost/c_api.h"
#include "../c_api/c_api_error.h"
namespace xgboost {
namespace data {
@@ -418,7 +426,7 @@ class FileAdapterBatch {
public:
class Line {
public:
Line(size_t row_idx, const uint32_t* feature_idx, const float* value,
Line(size_t row_idx, const uint32_t *feature_idx, const float *value,
size_t size)
: row_idx_(row_idx),
feature_idx_(feature_idx),
@@ -485,6 +493,112 @@ class FileAdapter : dmlc::DataIter<FileAdapterBatch> {
dmlc::Parser<uint32_t>* parser_;
};
/*!
 * \brief Data iterator that takes callback to return data, used in JVM package for
 *  accepting data iterator.
 *
 * Each successful Next() copies one CSR batch supplied through the callback into
 * buffers owned by this object and exposes it through Value().
 */
class IteratorAdapter : public dmlc::DataIter<FileAdapterBatch> {
 public:
  /*!
   * \brief Construct the adapter.
   * \param data_handle   Opaque handle passed back to next_callback on every call.
   * \param next_callback Function invoked by Next() to request the next batch.
   */
  IteratorAdapter(DataIterHandle data_handle,
                  XGBCallbackDataIterNext* next_callback)
      : columns_{data::kAdapterUnknownSize},
        row_offset_{0},
        at_first_{true},
        data_handle_{data_handle},
        next_callback_{next_callback} {}

  /*! \brief Resetting is unsupported: the check fails once a batch has been fetched. */
  void BeforeFirst() override {
    CHECK(at_first_) << "Cannot reset IteratorAdapter";
  }

  /*!
   * \brief Ask the callback for the next batch.
   * \return true when the callback returns a non-zero value, false otherwise.
   */
  bool Next() override {
    // Captureless lambda handed to the callback; it routes the batch back into
    // this object.  API_BEGIN/API_END come from c_api_error.h.
    auto set_data = [](void *handle, XGBoostBatchCSR batch) -> int {
      API_BEGIN();
      static_cast<IteratorAdapter *>(handle)->SetData(batch);
      API_END();
    };

    const bool fetched = next_callback_(data_handle_, set_data, this) != 0;
    if (!fetched) {
      return false;
    }
    at_first_ = false;
    return true;
  }

  /*!
   * \brief Access the batch built by the latest SetData() call.
   *
   * batch_ is only populated inside SetData(), so it is empty until then.
   */
  FileAdapterBatch const& Value() const override {
    return *batch_;
  }

  /*!
   * \brief Callback target: copy \p batch into the internal buffers and rebuild
   *  the row block / adapter batch on top of them.
   * \param batch CSR batch to copy; label and weight are optional (may be nullptr),
   *  as are index and value.
   */
  void SetData(const XGBoostBatchCSR& batch) {
    // batch.size rows are described by batch.size + 1 row pointers.
    offset_.assign(batch.offset, batch.offset + batch.size + 1);
    // Remember the original bounds before the offsets are rebased below; they
    // select the slice of index/value that belongs to this batch.
    const size_t first_entry = offset_.front();
    const size_t last_entry = offset_.back();

    label_.clear();
    if (batch.label != nullptr) {
      label_.assign(batch.label, batch.label + batch.size);
    }

    weight_.clear();
    if (batch.weight != nullptr) {
      weight_.assign(batch.weight, batch.weight + batch.size);
    }

    index_.clear();
    if (batch.index != nullptr) {
      index_.assign(batch.index + first_entry, batch.index + last_entry);
    }

    value_.clear();
    if (batch.value != nullptr) {
      value_.assign(batch.value + first_entry, batch.value + last_entry);
    }

    // Rebase the row pointers so that they index into the copied slices.
    if (first_entry != 0) {
      for (size_t i = 0; i < offset_.size(); ++i) {
        offset_[i] -= first_entry;
      }
    }

    // Every batch must report the same number of columns as the previous one.
    CHECK(columns_ == data::kAdapterUnknownSize || columns_ == batch.columns)
        << "Number of columns between batches changed from " << columns_
        << " to " << batch.columns;
    columns_ = batch.columns;

    // Point the row block at the freshly filled buffers.
    block_.size = batch.size;
    block_.offset = dmlc::BeginPtr(offset_);
    block_.label = dmlc::BeginPtr(label_);
    block_.weight = dmlc::BeginPtr(weight_);
    block_.qid = nullptr;
    block_.field = nullptr;
    block_.index = dmlc::BeginPtr(index_);
    block_.value = dmlc::BeginPtr(value_);

    batch_.reset(new FileAdapterBatch(&block_, row_offset_));
    // Advance the running row count by the number of rows in this batch.
    row_offset_ += offset_.size() - 1;
  }

  /*! \brief Column count reported by the last batch, or kAdapterUnknownSize before any. */
  size_t NumColumns() const { return columns_; }
  /*! \brief The total number of rows is not known up front. */
  size_t NumRows() const { return kAdapterUnknownSize; }

 private:
  // Owned copies of the current batch's CSR arrays.
  std::vector<size_t> offset_;
  std::vector<dmlc::real_t> label_;
  std::vector<dmlc::real_t> weight_;
  std::vector<uint32_t> index_;
  std::vector<dmlc::real_t> value_;
  // Number of columns; kAdapterUnknownSize until the first batch arrives.
  size_t columns_;
  // Number of rows accumulated over the batches set so far.
  size_t row_offset_;
  // True until the first successful Next().
  bool at_first_;
  // Handle to the external iterator.
  DataIterHandle data_handle_;
  // Callback used to fetch the data.
  XGBCallbackDataIterNext *next_callback_;
  // Internal row block; its pointers refer to the vectors above.
  dmlc::RowBlock<uint32_t> block_;
  // Adapter batch wrapping block_.
  std::unique_ptr<FileAdapterBatch> batch_;
};
class DMatrixSliceAdapterBatch {
public:
// Fetch metainfo values according to sliced rows