Support building SimpleDMatrix from Arrow data format (#7512)
* Integrate with Arrow C data API. * Support Arrow dataset. * Support Arrow table. Co-authored-by: Xiaochang Wu <xiaochang.wu@intel.com> Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com> Co-authored-by: Zhang Zhang <zhang.zhang@intel.com>
This commit is contained in:
@@ -13,6 +13,8 @@
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <algorithm>
|
||||
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/base.h"
|
||||
@@ -22,6 +24,7 @@
|
||||
#include "array_interface.h"
|
||||
#include "../c_api/c_api_error.h"
|
||||
#include "../common/math.h"
|
||||
#include "arrow-cdi.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
@@ -676,11 +679,10 @@ class FileAdapter : dmlc::DataIter<FileAdapterBatch> {
|
||||
template <typename DataIterHandle, typename XGBCallbackDataIterNext, typename XGBoostBatchCSR>
|
||||
class IteratorAdapter : public dmlc::DataIter<FileAdapterBatch> {
|
||||
public:
|
||||
IteratorAdapter(DataIterHandle data_handle,
|
||||
XGBCallbackDataIterNext* next_callback)
|
||||
: columns_{data::kAdapterUnknownSize}, row_offset_{0},
|
||||
at_first_(true),
|
||||
data_handle_(data_handle), next_callback_(next_callback) {}
|
||||
IteratorAdapter(DataIterHandle data_handle, XGBCallbackDataIterNext* next_callback)
|
||||
: columns_{data::kAdapterUnknownSize},
|
||||
data_handle_(data_handle),
|
||||
next_callback_(next_callback) {}
|
||||
|
||||
// override functions
|
||||
void BeforeFirst() override {
|
||||
@@ -766,9 +768,9 @@ class IteratorAdapter : public dmlc::DataIter<FileAdapterBatch> {
|
||||
std::vector<dmlc::real_t> value_;
|
||||
|
||||
size_t columns_;
|
||||
size_t row_offset_;
|
||||
size_t row_offset_{0};
|
||||
// at the beginning.
|
||||
bool at_first_;
|
||||
bool at_first_{true};
|
||||
// handle to the iterator,
|
||||
DataIterHandle data_handle_;
|
||||
// call back to get the data.
|
||||
@@ -777,6 +779,358 @@ class IteratorAdapter : public dmlc::DataIter<FileAdapterBatch> {
|
||||
dmlc::RowBlock<uint32_t> block_;
|
||||
std::unique_ptr<FileAdapterBatch> batch_;
|
||||
};
|
||||
|
||||
enum ColumnDType : uint8_t {
|
||||
kUnknown,
|
||||
kInt8,
|
||||
kUInt8,
|
||||
kInt16,
|
||||
kUInt16,
|
||||
kInt32,
|
||||
kUInt32,
|
||||
kInt64,
|
||||
kUInt64,
|
||||
kFloat,
|
||||
kDouble
|
||||
};
|
||||
|
||||
class Column {
|
||||
public:
|
||||
Column() = default;
|
||||
|
||||
Column(size_t col_idx, size_t length, size_t null_count, const uint8_t* bitmap)
|
||||
: col_idx_{col_idx}, length_{length}, null_count_{null_count}, bitmap_{bitmap} {}
|
||||
|
||||
virtual ~Column() = default;
|
||||
|
||||
Column(const Column&) = delete;
|
||||
Column& operator=(const Column&) = delete;
|
||||
Column(Column&&) = delete;
|
||||
Column& operator=(Column&&) = delete;
|
||||
|
||||
// whether the valid bit is set for this element
|
||||
bool IsValid(size_t row_idx) const {
|
||||
return (!bitmap_ || (bitmap_[row_idx/8] & (1 << (row_idx%8))));
|
||||
}
|
||||
|
||||
virtual COOTuple GetElement(size_t row_idx) const = 0;
|
||||
|
||||
virtual bool IsValidElement(size_t row_idx) const = 0;
|
||||
|
||||
virtual std::vector<float> AsFloatVector() const = 0;
|
||||
|
||||
virtual std::vector<uint64_t> AsUint64Vector() const = 0;
|
||||
|
||||
size_t Length() const { return length_; }
|
||||
|
||||
protected:
|
||||
size_t col_idx_;
|
||||
size_t length_;
|
||||
size_t null_count_;
|
||||
const uint8_t* bitmap_;
|
||||
};
|
||||
|
||||
// Only columns of primitive types are supported. An ArrowColumnarBatch is a
|
||||
// collection of std::shared_ptr<PrimitiveColumn>. These columns can be of different data types.
|
||||
// Hence, PrimitiveColumn is a class template; and all concrete PrimitiveColumns
|
||||
// derive from the abstract class Column.
|
||||
template <typename T>
|
||||
class PrimitiveColumn : public Column {
|
||||
static constexpr float kNaN = std::numeric_limits<float>::quiet_NaN();
|
||||
|
||||
public:
|
||||
PrimitiveColumn(size_t idx, size_t length, size_t null_count,
|
||||
const uint8_t* bitmap, const T* data, float missing)
|
||||
: Column{idx, length, null_count, bitmap}, data_{data}, missing_{missing} {}
|
||||
|
||||
COOTuple GetElement(size_t row_idx) const override {
|
||||
CHECK(data_ && row_idx < length_) << "Column is empty or out-of-bound index of the column";
|
||||
return { row_idx, col_idx_, IsValidElement(row_idx) ?
|
||||
static_cast<float>(data_[row_idx]) : kNaN };
|
||||
}
|
||||
|
||||
bool IsValidElement(size_t row_idx) const override {
|
||||
// std::isfinite needs to cast to double to prevent msvc report error
|
||||
return IsValid(row_idx)
|
||||
&& std::isfinite(static_cast<double>(data_[row_idx]))
|
||||
&& static_cast<float>(data_[row_idx]) != missing_;
|
||||
}
|
||||
|
||||
std::vector<float> AsFloatVector() const override {
|
||||
CHECK(data_) << "Column is empty";
|
||||
std::vector<float> fv(length_);
|
||||
std::transform(data_, data_ + length_, fv.begin(),
|
||||
[](T v) { return static_cast<float>(v); });
|
||||
return fv;
|
||||
}
|
||||
|
||||
std::vector<uint64_t> AsUint64Vector() const override {
|
||||
CHECK(data_) << "Column is empty";
|
||||
std::vector<uint64_t> iv(length_);
|
||||
std::transform(data_, data_ + length_, iv.begin(),
|
||||
[](T v) { return static_cast<uint64_t>(v); });
|
||||
return iv;
|
||||
}
|
||||
|
||||
private:
|
||||
const T* data_;
|
||||
float missing_; // user specified missing value
|
||||
};
|
||||
|
||||
struct ColumnarMetaInfo {
|
||||
// data type of the column
|
||||
ColumnDType type{ColumnDType::kUnknown};
|
||||
// location of the column in an Arrow record batch
|
||||
int64_t loc{-1};
|
||||
};
|
||||
|
||||
struct ArrowSchemaImporter {
|
||||
std::vector<ColumnarMetaInfo> columns;
|
||||
|
||||
// map Arrow format strings to types
|
||||
static ColumnDType FormatMap(char const* format_str) {
|
||||
CHECK(format_str) << "Format string cannot be empty";
|
||||
switch (format_str[0]) {
|
||||
case 'c':
|
||||
return ColumnDType::kInt8;
|
||||
case 'C':
|
||||
return ColumnDType::kUInt8;
|
||||
case 's':
|
||||
return ColumnDType::kInt16;
|
||||
case 'S':
|
||||
return ColumnDType::kUInt16;
|
||||
case 'i':
|
||||
return ColumnDType::kInt32;
|
||||
case 'I':
|
||||
return ColumnDType::kUInt32;
|
||||
case 'l':
|
||||
return ColumnDType::kInt64;
|
||||
case 'L':
|
||||
return ColumnDType::kUInt64;
|
||||
case 'f':
|
||||
return ColumnDType::kFloat;
|
||||
case 'g':
|
||||
return ColumnDType::kDouble;
|
||||
default:
|
||||
CHECK(false) << "Column data type not supported by XGBoost";
|
||||
return ColumnDType::kUnknown;
|
||||
}
|
||||
}
|
||||
|
||||
void Import(struct ArrowSchema *schema) {
|
||||
if (schema) {
|
||||
CHECK(std::string(schema->format) == "+s"); // NOLINT
|
||||
CHECK(columns.empty());
|
||||
for (auto i = 0; i < schema->n_children; ++i) {
|
||||
std::string name{schema->children[i]->name};
|
||||
ColumnDType type = FormatMap(schema->children[i]->format);
|
||||
ColumnarMetaInfo col_info{type, i};
|
||||
columns.push_back(col_info);
|
||||
}
|
||||
if (schema->release) {
|
||||
schema->release(schema);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class ArrowColumnarBatch {
|
||||
public:
|
||||
ArrowColumnarBatch(struct ArrowArray *rb, struct ArrowSchemaImporter* schema)
|
||||
: rb_{rb}, schema_{schema} {
|
||||
CHECK(rb_) << "Cannot import non-existent record batch";
|
||||
CHECK(!schema_->columns.empty()) << "Cannot import record batch without a schema";
|
||||
}
|
||||
|
||||
size_t Import(float missing) {
|
||||
auto& infov = schema_->columns;
|
||||
for (size_t i = 0; i < infov.size(); ++i) {
|
||||
columns_.push_back(CreateColumn(i, infov[i], missing));
|
||||
}
|
||||
|
||||
// Compute the starting location for every row in this batch
|
||||
auto batch_size = rb_->length;
|
||||
auto num_columns = columns_.size();
|
||||
row_offsets_.resize(batch_size + 1, 0);
|
||||
for (auto i = 0; i < batch_size; ++i) {
|
||||
row_offsets_[i+1] = row_offsets_[i];
|
||||
for (size_t j = 0; j < num_columns; ++j) {
|
||||
if (GetColumn(j).IsValidElement(i)) {
|
||||
row_offsets_[i+1]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
// return number of elements in the batch
|
||||
return row_offsets_.back();
|
||||
}
|
||||
|
||||
ArrowColumnarBatch(const ArrowColumnarBatch&) = delete;
|
||||
ArrowColumnarBatch& operator=(const ArrowColumnarBatch&) = delete;
|
||||
ArrowColumnarBatch(ArrowColumnarBatch&&) = delete;
|
||||
ArrowColumnarBatch& operator=(ArrowColumnarBatch&&) = delete;
|
||||
|
||||
virtual ~ArrowColumnarBatch() {
|
||||
if (rb_ && rb_->release) {
|
||||
rb_->release(rb_);
|
||||
rb_ = nullptr;
|
||||
}
|
||||
columns_.clear();
|
||||
}
|
||||
|
||||
size_t Size() const { return rb_ ? rb_->length : 0; }
|
||||
|
||||
size_t NumColumns() const { return columns_.size(); }
|
||||
|
||||
size_t NumElements() const { return row_offsets_.back(); }
|
||||
|
||||
const Column& GetColumn(size_t col_idx) const {
|
||||
return *columns_[col_idx];
|
||||
}
|
||||
|
||||
void ShiftRowOffsets(size_t batch_offset) {
|
||||
std::transform(row_offsets_.begin(), row_offsets_.end(), row_offsets_.begin(),
|
||||
[=](size_t c) { return c + batch_offset; });
|
||||
}
|
||||
|
||||
const std::vector<size_t>& RowOffsets() const { return row_offsets_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<Column> CreateColumn(size_t idx,
|
||||
ColumnarMetaInfo info,
|
||||
float missing) const {
|
||||
if (info.loc < 0) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto loc_in_batch = info.loc;
|
||||
auto length = rb_->length;
|
||||
auto null_count = rb_->null_count;
|
||||
auto buffers0 = rb_->children[loc_in_batch]->buffers[0];
|
||||
auto buffers1 = rb_->children[loc_in_batch]->buffers[1];
|
||||
const uint8_t* bitmap = buffers0 ? reinterpret_cast<const uint8_t*>(buffers0) : nullptr;
|
||||
const uint8_t* data = buffers1 ? reinterpret_cast<const uint8_t*>(buffers1) : nullptr;
|
||||
|
||||
// if null_count is not computed, compute it here
|
||||
if (null_count < 0) {
|
||||
if (!bitmap) {
|
||||
null_count = 0;
|
||||
} else {
|
||||
null_count = length;
|
||||
for (auto i = 0; i < length; ++i) {
|
||||
if (bitmap[i/8] & (1 << (i%8))) {
|
||||
null_count--;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch (info.type) {
|
||||
case ColumnDType::kInt8:
|
||||
return std::make_shared<PrimitiveColumn<int8_t>>(
|
||||
idx, length, null_count, bitmap,
|
||||
reinterpret_cast<const int8_t*>(data), missing);
|
||||
case ColumnDType::kUInt8:
|
||||
return std::make_shared<PrimitiveColumn<uint8_t>>(
|
||||
idx, length, null_count, bitmap, data, missing);
|
||||
case ColumnDType::kInt16:
|
||||
return std::make_shared<PrimitiveColumn<int16_t>>(
|
||||
idx, length, null_count, bitmap,
|
||||
reinterpret_cast<const int16_t*>(data), missing);
|
||||
case ColumnDType::kUInt16:
|
||||
return std::make_shared<PrimitiveColumn<uint16_t>>(
|
||||
idx, length, null_count, bitmap,
|
||||
reinterpret_cast<const uint16_t*>(data), missing);
|
||||
case ColumnDType::kInt32:
|
||||
return std::make_shared<PrimitiveColumn<int32_t>>(
|
||||
idx, length, null_count, bitmap,
|
||||
reinterpret_cast<const int32_t*>(data), missing);
|
||||
case ColumnDType::kUInt32:
|
||||
return std::make_shared<PrimitiveColumn<uint32_t>>(
|
||||
idx, length, null_count, bitmap,
|
||||
reinterpret_cast<const uint32_t*>(data), missing);
|
||||
case ColumnDType::kInt64:
|
||||
return std::make_shared<PrimitiveColumn<int64_t>>(
|
||||
idx, length, null_count, bitmap,
|
||||
reinterpret_cast<const int64_t*>(data), missing);
|
||||
case ColumnDType::kUInt64:
|
||||
return std::make_shared<PrimitiveColumn<uint64_t>>(
|
||||
idx, length, null_count, bitmap,
|
||||
reinterpret_cast<const uint64_t*>(data), missing);
|
||||
case ColumnDType::kFloat:
|
||||
return std::make_shared<PrimitiveColumn<float>>(
|
||||
idx, length, null_count, bitmap,
|
||||
reinterpret_cast<const float*>(data), missing);
|
||||
case ColumnDType::kDouble:
|
||||
return std::make_shared<PrimitiveColumn<double>>(
|
||||
idx, length, null_count, bitmap,
|
||||
reinterpret_cast<const double*>(data), missing);
|
||||
default:
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
struct ArrowArray* rb_;
|
||||
struct ArrowSchemaImporter* schema_;
|
||||
std::vector<std::shared_ptr<Column>> columns_;
|
||||
std::vector<size_t> row_offsets_;
|
||||
};
|
||||
|
||||
using ArrowColumnarBatchVec = std::vector<std::unique_ptr<ArrowColumnarBatch>>;
|
||||
class RecordBatchesIterAdapter: public dmlc::DataIter<ArrowColumnarBatchVec> {
|
||||
public:
|
||||
RecordBatchesIterAdapter(XGDMatrixCallbackNext *next_callback,
|
||||
int nthread)
|
||||
: next_callback_{next_callback},
|
||||
nbatches_{nthread} {}
|
||||
|
||||
void BeforeFirst() override {
|
||||
CHECK(at_first_) << "Cannot reset RecordBatchesIterAdapter";
|
||||
}
|
||||
|
||||
bool Next() override {
|
||||
batches_.clear();
|
||||
while (batches_.size() < static_cast<size_t>(nbatches_) && (*next_callback_)(this) != 0) {
|
||||
at_first_ = false;
|
||||
}
|
||||
|
||||
if (batches_.size() > 0) {
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
void SetData(struct ArrowArray* rb, struct ArrowSchema* schema) {
|
||||
// Schema is only imported once at the beginning, regardless how many
|
||||
// baches are comming.
|
||||
// But even schema is not imported we still need to release its C data
|
||||
// exported from Arrow.
|
||||
if (at_first_ && schema) {
|
||||
schema_.Import(schema);
|
||||
} else {
|
||||
if (schema && schema->release) {
|
||||
schema->release(schema);
|
||||
}
|
||||
}
|
||||
if (rb) {
|
||||
batches_.push_back(std::make_unique<ArrowColumnarBatch>(rb, &schema_));
|
||||
}
|
||||
}
|
||||
|
||||
const ArrowColumnarBatchVec& Value() const override {
|
||||
return batches_;
|
||||
}
|
||||
|
||||
size_t NumColumns() const { return schema_.columns.size(); }
|
||||
size_t NumRows() const { return kAdapterUnknownSize; }
|
||||
|
||||
private:
|
||||
XGDMatrixCallbackNext *next_callback_;
|
||||
bool at_first_{true};
|
||||
int nbatches_;
|
||||
struct ArrowSchemaImporter schema_;
|
||||
ArrowColumnarBatchVec batches_;
|
||||
};
|
||||
}; // namespace data
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_DATA_ADAPTER_H_
|
||||
|
||||
Reference in New Issue
Block a user