Add number of columns to native data iterator. (#5202)
* Change native data iter into an adapter.
This commit is contained in:
@@ -1,18 +1,26 @@
|
||||
/*!
|
||||
* Copyright (c) 2019 by Contributors
|
||||
* Copyright (c) 2019~2020 by Contributors
|
||||
* \file adapter.h
|
||||
*/
|
||||
#ifndef XGBOOST_DATA_ADAPTER_H_
|
||||
#define XGBOOST_DATA_ADAPTER_H_
|
||||
#include <dmlc/data.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/span.h"
|
||||
#include "xgboost/c_api.h"
|
||||
|
||||
#include "../c_api/c_api_error.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
@@ -418,7 +426,7 @@ class FileAdapterBatch {
|
||||
public:
|
||||
class Line {
|
||||
public:
|
||||
Line(size_t row_idx, const uint32_t* feature_idx, const float* value,
|
||||
Line(size_t row_idx, const uint32_t *feature_idx, const float *value,
|
||||
size_t size)
|
||||
: row_idx_(row_idx),
|
||||
feature_idx_(feature_idx),
|
||||
@@ -485,6 +493,112 @@ class FileAdapter : dmlc::DataIter<FileAdapterBatch> {
|
||||
dmlc::Parser<uint32_t>* parser_;
|
||||
};
|
||||
|
||||
/*! \brief Data iterator that takes callback to return data, used in JVM package for
|
||||
* accepting data iterator. */
|
||||
class IteratorAdapter : public dmlc::DataIter<FileAdapterBatch> {
|
||||
public:
|
||||
IteratorAdapter(DataIterHandle data_handle,
|
||||
XGBCallbackDataIterNext* next_callback)
|
||||
: columns_{data::kAdapterUnknownSize}, row_offset_{0},
|
||||
at_first_(true),
|
||||
data_handle_(data_handle), next_callback_(next_callback) {}
|
||||
|
||||
// override functions
|
||||
void BeforeFirst() override {
|
||||
CHECK(at_first_) << "Cannot reset IteratorAdapter";
|
||||
}
|
||||
|
||||
bool Next() override {
|
||||
if ((*next_callback_)(
|
||||
data_handle_,
|
||||
[](void *handle, XGBoostBatchCSR batch) -> int {
|
||||
API_BEGIN();
|
||||
static_cast<IteratorAdapter *>(handle)->SetData(batch);
|
||||
API_END();
|
||||
},
|
||||
this) != 0) {
|
||||
at_first_ = false;
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
FileAdapterBatch const& Value() const override {
|
||||
return *batch_.get();
|
||||
}
|
||||
|
||||
// callback to set the data
|
||||
void SetData(const XGBoostBatchCSR& batch) {
|
||||
offset_.clear();
|
||||
label_.clear();
|
||||
weight_.clear();
|
||||
index_.clear();
|
||||
value_.clear();
|
||||
offset_.insert(offset_.end(), batch.offset, batch.offset + batch.size + 1);
|
||||
|
||||
if (batch.label != nullptr) {
|
||||
label_.insert(label_.end(), batch.label, batch.label + batch.size);
|
||||
}
|
||||
if (batch.weight != nullptr) {
|
||||
weight_.insert(weight_.end(), batch.weight, batch.weight + batch.size);
|
||||
}
|
||||
if (batch.index != nullptr) {
|
||||
index_.insert(index_.end(), batch.index + offset_[0],
|
||||
batch.index + offset_.back());
|
||||
}
|
||||
if (batch.value != nullptr) {
|
||||
value_.insert(value_.end(), batch.value + offset_[0],
|
||||
batch.value + offset_.back());
|
||||
}
|
||||
if (offset_[0] != 0) {
|
||||
size_t base = offset_[0];
|
||||
for (size_t &item : offset_) {
|
||||
item -= base;
|
||||
}
|
||||
}
|
||||
CHECK(columns_ == data::kAdapterUnknownSize || columns_ == batch.columns)
|
||||
<< "Number of columns between batches changed from " << columns_
|
||||
<< " to " << batch.columns;
|
||||
|
||||
columns_ = batch.columns;
|
||||
block_.size = batch.size;
|
||||
|
||||
block_.offset = dmlc::BeginPtr(offset_);
|
||||
block_.label = dmlc::BeginPtr(label_);
|
||||
block_.weight = dmlc::BeginPtr(weight_);
|
||||
block_.qid = nullptr;
|
||||
block_.field = nullptr;
|
||||
block_.index = dmlc::BeginPtr(index_);
|
||||
block_.value = dmlc::BeginPtr(value_);
|
||||
|
||||
batch_.reset(new FileAdapterBatch(&block_, row_offset_));
|
||||
row_offset_ += offset_.size() - 1;
|
||||
}
|
||||
|
||||
size_t NumColumns() const { return columns_; }
|
||||
size_t NumRows() const { return kAdapterUnknownSize; }
|
||||
|
||||
private:
|
||||
std::vector<size_t> offset_;
|
||||
std::vector<dmlc::real_t> label_;
|
||||
std::vector<dmlc::real_t> weight_;
|
||||
std::vector<uint32_t> index_;
|
||||
std::vector<dmlc::real_t> value_;
|
||||
|
||||
size_t columns_;
|
||||
size_t row_offset_;
|
||||
// at the beinning.
|
||||
bool at_first_;
|
||||
// handle to the iterator,
|
||||
DataIterHandle data_handle_;
|
||||
// call back to get the data.
|
||||
XGBCallbackDataIterNext *next_callback_;
|
||||
// internal Rowblock
|
||||
dmlc::RowBlock<uint32_t> block_;
|
||||
std::unique_ptr<FileAdapterBatch> batch_;
|
||||
};
|
||||
|
||||
class DMatrixSliceAdapterBatch {
|
||||
public:
|
||||
// Fetch metainfo values according to sliced rows
|
||||
|
||||
Reference in New Issue
Block a user