Support building SimpleDMatrix from Arrow data format (#7512)
* Integrate with Arrow C data API. * Support Arrow dataset. * Support Arrow table. Co-authored-by: Xiaochang Wu <xiaochang.wu@intel.com> Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com> Co-authored-by: Zhang Zhang <zhang.zhang@intel.com>
This commit is contained in:
parent
6b6849b001
commit
613ec36c5a
@ -502,12 +502,29 @@ XGB_DLL int XGProxyDMatrixSetDataCSR(DMatrixHandle handle, char const *indptr,
|
|||||||
char const *indices, char const *data,
|
char const *indices, char const *data,
|
||||||
bst_ulong ncol);
|
bst_ulong ncol);
|
||||||
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* ==========================- End data callback APIs ==========================
|
* ==========================- End data callback APIs ==========================
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
/*!
 * \brief Pass one Arrow record batch, exported through the Arrow C data interface, to the
 *        record batch adapter behind \p data_handle. The implementation casts the handle
 *        to the adapter and forwards both pointers to its SetData method.
 *
 * \param data_handle Handle of the record batch iterator adapter.
 * \param ptr_array   Pointer to the exported `struct ArrowArray` (the record batch).
 * \param ptr_schema  Pointer to the exported `struct ArrowSchema` describing the batch.
 *
 * \return Presumably 0 when success, -1 when failure happens, like the neighbouring
 *         functions in this header -- TODO confirm.
 */
XGB_DLL int XGImportArrowRecordBatch(DataIterHandle data_handle, void *ptr_array, void *ptr_schema);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Construct DMatrix from arrow using callbacks. Arrow related C API is not stable
|
||||||
|
* and subject to change in the future.
|
||||||
|
*
|
||||||
|
* \param next Callback function for fetching arrow records.
|
||||||
|
* \param json_config JSON encoded configuration. Required values are:
|
||||||
|
*
|
||||||
|
* - missing
|
||||||
|
* - nthread
|
||||||
|
*
|
||||||
|
* \param out The created DMatrix.
|
||||||
|
*
|
||||||
|
* \return 0 when success, -1 when failure happens
|
||||||
|
*/
|
||||||
|
XGB_DLL int XGDMatrixCreateFromArrowCallback(XGDMatrixCallbackNext *next, char const *json_config,
|
||||||
|
DMatrixHandle *out);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief create a new dmatrix from sliced content of existing matrix
|
* \brief create a new dmatrix from sliced content of existing matrix
|
||||||
|
|||||||
@ -2,10 +2,11 @@
|
|||||||
# pylint: disable=too-many-return-statements, import-error
|
# pylint: disable=too-many-return-statements, import-error
|
||||||
'''Data dispatching for DMatrix.'''
|
'''Data dispatching for DMatrix.'''
|
||||||
import ctypes
|
import ctypes
|
||||||
|
from distutils import version
|
||||||
import json
|
import json
|
||||||
import warnings
|
import warnings
|
||||||
import os
|
import os
|
||||||
from typing import Any, Tuple, Callable, Optional, List, Union
|
from typing import Any, Tuple, Callable, Optional, List, Union, Iterator
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -466,6 +467,92 @@ def _from_dt_df(
|
|||||||
return handle, feature_names, feature_types
|
return handle, feature_names, feature_types
|
||||||
|
|
||||||
|
|
||||||
|
def _is_arrow(data) -> bool:
|
||||||
|
try:
|
||||||
|
import pyarrow as pa
|
||||||
|
from pyarrow import dataset as arrow_dataset
|
||||||
|
return isinstance(data, (pa.Table, arrow_dataset.Dataset))
|
||||||
|
except ImportError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def record_batch_data_iter(data_iter: Iterator) -> Callable:
    """Data iterator used to ingest Arrow columnar record batches.  We are not using
    class DataIter because it is only intended for building Device DMatrix and external
    memory DMatrix.

    Parameters
    ----------
    data_iter :
        Iterator yielding the record batches to ingest.

    Returns
    -------
    callback :
        Function taking the native iterator handle.  Each call exports the next record
        batch through the Arrow C data interface, hands it to
        ``XGImportArrowRecordBatch`` and returns 1; it returns 0 once ``data_iter`` is
        exhausted.

    """
    from pyarrow.cffi import ffi

    # Every exported struct is appended here and never removed, presumably so that the
    # memory allocated by ``ffi.new`` outlives the native code reading it.
    c_schemas: List[ffi.CData] = []
    c_arrays: List[ffi.CData] = []

    def _next(data_handle: int) -> int:
        # Only fetching the next batch is expected to signal exhaustion, so the
        # ``try`` is limited to that call.
        try:
            batch = next(data_iter)
        except StopIteration:
            return 0

        c_schemas.append(ffi.new("struct ArrowSchema*"))
        c_arrays.append(ffi.new("struct ArrowArray*"))
        ptr_schema = int(ffi.cast("uintptr_t", c_schemas[-1]))
        ptr_array = int(ffi.cast("uintptr_t", c_arrays[-1]))
        # pylint: disable=protected-access
        batch._export_to_c(ptr_array, ptr_schema)
        _check_call(
            _LIB.XGImportArrowRecordBatch(
                ctypes.c_void_p(data_handle),
                ctypes.c_void_p(ptr_array),
                ctypes.c_void_p(ptr_schema),
            )
        )
        return 1

    return _next
|
||||||
|
|
||||||
|
|
||||||
|
def _from_arrow(
    data,
    missing: float,
    nthread: int,
    feature_names: Optional[List[str]],
    feature_types: Optional[List[str]],
    enable_categorical: bool,
) -> Tuple[ctypes.c_void_p, Optional[List[str]], Optional[List[str]]]:
    """Build a DMatrix handle from a pyarrow table or dataset.

    Parameters
    ----------
    data :
        Arrow object exposing ``schema`` and ``to_batches``.
    missing :
        Value treated as missing; forwarded in the JSON configuration.
    nthread :
        Number of threads; forwarded in the JSON configuration.
    feature_names, feature_types :
        Returned unchanged alongside the handle.
    enable_categorical :
        Must be false; categorical data is rejected.

    Raises
    ------
    ValueError
        If a column is neither integer nor floating point, or if
        ``enable_categorical`` is set.
    """
    import pyarrow as pa

    if not all(
        pa.types.is_integer(t) or pa.types.is_floating(t) for t in data.schema.types
    ):
        raise ValueError(
            "Features in dataset can only be integers or floating point number"
        )
    if enable_categorical:
        raise ValueError("categorical data in arrow is not supported yet.")

    # Only the major version matters here.  Parsing it directly avoids the deprecated
    # distutils ``StrictVersion``, which rejects version strings such as "7.0.0.dev123".
    major = int(pa.__version__.split(".", 1)[0])
    if major == 4:
        rb_iter = iter(data.to_batches())
    else:
        # use_async=True to workaround pyarrow 6.0.1 hang,
        # see Modin-3982 and ARROW-15362
        rb_iter = iter(data.to_batches(use_async=True))
    it = record_batch_data_iter(rb_iter)
    next_callback = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p)(it)
    handle = ctypes.c_void_p()

    config = bytes(json.dumps({"missing": missing, "nthread": nthread}), "utf-8")
    _check_call(
        _LIB.XGDMatrixCreateFromArrowCallback(
            next_callback,
            config,
            ctypes.byref(handle),
        )
    )
    return handle, feature_names, feature_types
|
||||||
|
|
||||||
|
|
||||||
def _is_cudf_df(data) -> bool:
|
def _is_cudf_df(data) -> bool:
|
||||||
return lazy_isinstance(data, "cudf.core.dataframe", "DataFrame")
|
return lazy_isinstance(data, "cudf.core.dataframe", "DataFrame")
|
||||||
|
|
||||||
@ -814,6 +901,9 @@ def dispatch_data_backend(
|
|||||||
return _from_pandas_series(
|
return _from_pandas_series(
|
||||||
data, missing, threads, enable_categorical, feature_names, feature_types
|
data, missing, threads, enable_categorical, feature_names, feature_types
|
||||||
)
|
)
|
||||||
|
if _is_arrow(data):
|
||||||
|
return _from_arrow(
|
||||||
|
data, missing, threads, feature_names, feature_types, enable_categorical)
|
||||||
if _has_array_protocol(data):
|
if _has_array_protocol(data):
|
||||||
array = np.asarray(data)
|
array = np.asarray(data)
|
||||||
return _from_numpy_array(array, missing, threads, feature_names, feature_types)
|
return _from_numpy_array(array, missing, threads, feature_names, feature_types)
|
||||||
@ -954,6 +1044,7 @@ def dispatch_meta_backend(
|
|||||||
_meta_from_numpy(data, name, dtype, handle)
|
_meta_from_numpy(data, name, dtype, handle)
|
||||||
return
|
return
|
||||||
if _has_array_protocol(data):
|
if _has_array_protocol(data):
|
||||||
|
# pyarrow goes here.
|
||||||
array = np.asarray(data)
|
array = np.asarray(data)
|
||||||
_meta_from_numpy(array, name, dtype, handle)
|
_meta_from_numpy(array, name, dtype, handle)
|
||||||
return
|
return
|
||||||
|
|||||||
@ -416,6 +416,27 @@ XGB_DLL int XGDMatrixCreateFromDT(void** data, const char** feature_stypes,
|
|||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Forward one exported Arrow record batch (array + schema) to the record batch adapter
// identified by `data_handle`.
XGB_DLL int XGImportArrowRecordBatch(DataIterHandle data_handle, void *ptr_array,
                                     void *ptr_schema) {
  API_BEGIN();
  // The handle is dereferenced right below; reject null through the regular error path
  // of this API function instead of invoking undefined behaviour.
  CHECK(data_handle) << "Invalid handle: the record batch adapter must not be null.";
  auto *adapter = static_cast<data::RecordBatchesIterAdapter *>(data_handle);
  adapter->SetData(static_cast<struct ArrowArray *>(ptr_array),
                   static_cast<struct ArrowSchema *>(ptr_schema));
  API_END();
}
|
||||||
|
|
||||||
|
// Create a DMatrix by pulling Arrow record batches through `next`. `json_config` supplies
// the `missing` value and `nthread`; the new DMatrix handle is written to `out`.
XGB_DLL int XGDMatrixCreateFromArrowCallback(XGDMatrixCallbackNext *next, char const *json_config,
                                             DMatrixHandle *out) {
  API_BEGIN();
  // All three pointers are dereferenced below; fail through the regular error path of
  // this API function rather than crash on null.
  CHECK(next) << "Invalid argument: the record batch callback must not be null.";
  CHECK(json_config) << "Invalid argument: the JSON configuration must not be null.";
  CHECK(out) << "Invalid argument: the output handle must not be null.";

  auto config = Json::Load(StringView{json_config});
  auto missing = GetMissing(config);
  int32_t n_threads = get<Integer const>(config["nthread"]);
  n_threads = common::OmpGetNumThreads(n_threads);
  // The adapter only needs to live for the duration of DMatrix::Create.
  data::RecordBatchesIterAdapter adapter(next, n_threads);
  *out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
  API_END();
}
|
||||||
|
|
||||||
XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
|
XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
|
||||||
const int* idxset,
|
const int* idxset,
|
||||||
xgboost::bst_ulong len,
|
xgboost::bst_ulong len,
|
||||||
|
|||||||
@ -13,6 +13,8 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <map>
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
#include "xgboost/logging.h"
|
#include "xgboost/logging.h"
|
||||||
#include "xgboost/base.h"
|
#include "xgboost/base.h"
|
||||||
@ -22,6 +24,7 @@
|
|||||||
#include "array_interface.h"
|
#include "array_interface.h"
|
||||||
#include "../c_api/c_api_error.h"
|
#include "../c_api/c_api_error.h"
|
||||||
#include "../common/math.h"
|
#include "../common/math.h"
|
||||||
|
#include "arrow-cdi.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace data {
|
namespace data {
|
||||||
@ -676,11 +679,10 @@ class FileAdapter : dmlc::DataIter<FileAdapterBatch> {
|
|||||||
template <typename DataIterHandle, typename XGBCallbackDataIterNext, typename XGBoostBatchCSR>
|
template <typename DataIterHandle, typename XGBCallbackDataIterNext, typename XGBoostBatchCSR>
|
||||||
class IteratorAdapter : public dmlc::DataIter<FileAdapterBatch> {
|
class IteratorAdapter : public dmlc::DataIter<FileAdapterBatch> {
|
||||||
public:
|
public:
|
||||||
IteratorAdapter(DataIterHandle data_handle,
|
IteratorAdapter(DataIterHandle data_handle, XGBCallbackDataIterNext* next_callback)
|
||||||
XGBCallbackDataIterNext* next_callback)
|
: columns_{data::kAdapterUnknownSize},
|
||||||
: columns_{data::kAdapterUnknownSize}, row_offset_{0},
|
data_handle_(data_handle),
|
||||||
at_first_(true),
|
next_callback_(next_callback) {}
|
||||||
data_handle_(data_handle), next_callback_(next_callback) {}
|
|
||||||
|
|
||||||
// override functions
|
// override functions
|
||||||
void BeforeFirst() override {
|
void BeforeFirst() override {
|
||||||
@ -766,9 +768,9 @@ class IteratorAdapter : public dmlc::DataIter<FileAdapterBatch> {
|
|||||||
std::vector<dmlc::real_t> value_;
|
std::vector<dmlc::real_t> value_;
|
||||||
|
|
||||||
size_t columns_;
|
size_t columns_;
|
||||||
size_t row_offset_;
|
size_t row_offset_{0};
|
||||||
// at the beginning.
|
// at the beginning.
|
||||||
bool at_first_;
|
bool at_first_{true};
|
||||||
// handle to the iterator,
|
// handle to the iterator,
|
||||||
DataIterHandle data_handle_;
|
DataIterHandle data_handle_;
|
||||||
// call back to get the data.
|
// call back to get the data.
|
||||||
@ -777,6 +779,358 @@ class IteratorAdapter : public dmlc::DataIter<FileAdapterBatch> {
|
|||||||
dmlc::RowBlock<uint32_t> block_;
|
dmlc::RowBlock<uint32_t> block_;
|
||||||
std::unique_ptr<FileAdapterBatch> batch_;
|
std::unique_ptr<FileAdapterBatch> batch_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Element type of an Arrow column. kUnknown is both the default of ColumnarMetaInfo and
// the value produced for an unrecognised Arrow format string.
enum ColumnDType : uint8_t {
  kUnknown = 0,
  // signed / unsigned integers, by width
  kInt8 = 1,
  kUInt8 = 2,
  kInt16 = 3,
  kUInt16 = 4,
  kInt32 = 5,
  kUInt32 = 6,
  kInt64 = 7,
  kUInt64 = 8,
  // floating point
  kFloat = 9,
  kDouble = 10
};
|
||||||
|
|
||||||
|
class Column {
|
||||||
|
public:
|
||||||
|
Column() = default;
|
||||||
|
|
||||||
|
Column(size_t col_idx, size_t length, size_t null_count, const uint8_t* bitmap)
|
||||||
|
: col_idx_{col_idx}, length_{length}, null_count_{null_count}, bitmap_{bitmap} {}
|
||||||
|
|
||||||
|
virtual ~Column() = default;
|
||||||
|
|
||||||
|
Column(const Column&) = delete;
|
||||||
|
Column& operator=(const Column&) = delete;
|
||||||
|
Column(Column&&) = delete;
|
||||||
|
Column& operator=(Column&&) = delete;
|
||||||
|
|
||||||
|
// whether the valid bit is set for this element
|
||||||
|
bool IsValid(size_t row_idx) const {
|
||||||
|
return (!bitmap_ || (bitmap_[row_idx/8] & (1 << (row_idx%8))));
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual COOTuple GetElement(size_t row_idx) const = 0;
|
||||||
|
|
||||||
|
virtual bool IsValidElement(size_t row_idx) const = 0;
|
||||||
|
|
||||||
|
virtual std::vector<float> AsFloatVector() const = 0;
|
||||||
|
|
||||||
|
virtual std::vector<uint64_t> AsUint64Vector() const = 0;
|
||||||
|
|
||||||
|
size_t Length() const { return length_; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
size_t col_idx_;
|
||||||
|
size_t length_;
|
||||||
|
size_t null_count_;
|
||||||
|
const uint8_t* bitmap_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Only columns of primitive types are supported. An ArrowColumnarBatch is a
|
||||||
|
// collection of std::shared_ptr<PrimitiveColumn>. These columns can be of different data types.
|
||||||
|
// Hence, PrimitiveColumn is a class template; and all concrete PrimitiveColumns
|
||||||
|
// derive from the abstract class Column.
|
||||||
|
template <typename T>
|
||||||
|
class PrimitiveColumn : public Column {
|
||||||
|
static constexpr float kNaN = std::numeric_limits<float>::quiet_NaN();
|
||||||
|
|
||||||
|
public:
|
||||||
|
PrimitiveColumn(size_t idx, size_t length, size_t null_count,
|
||||||
|
const uint8_t* bitmap, const T* data, float missing)
|
||||||
|
: Column{idx, length, null_count, bitmap}, data_{data}, missing_{missing} {}
|
||||||
|
|
||||||
|
COOTuple GetElement(size_t row_idx) const override {
|
||||||
|
CHECK(data_ && row_idx < length_) << "Column is empty or out-of-bound index of the column";
|
||||||
|
return { row_idx, col_idx_, IsValidElement(row_idx) ?
|
||||||
|
static_cast<float>(data_[row_idx]) : kNaN };
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IsValidElement(size_t row_idx) const override {
|
||||||
|
// std::isfinite needs to cast to double to prevent msvc report error
|
||||||
|
return IsValid(row_idx)
|
||||||
|
&& std::isfinite(static_cast<double>(data_[row_idx]))
|
||||||
|
&& static_cast<float>(data_[row_idx]) != missing_;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> AsFloatVector() const override {
|
||||||
|
CHECK(data_) << "Column is empty";
|
||||||
|
std::vector<float> fv(length_);
|
||||||
|
std::transform(data_, data_ + length_, fv.begin(),
|
||||||
|
[](T v) { return static_cast<float>(v); });
|
||||||
|
return fv;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<uint64_t> AsUint64Vector() const override {
|
||||||
|
CHECK(data_) << "Column is empty";
|
||||||
|
std::vector<uint64_t> iv(length_);
|
||||||
|
std::transform(data_, data_ + length_, iv.begin(),
|
||||||
|
[](T v) { return static_cast<uint64_t>(v); });
|
||||||
|
return iv;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const T* data_;
|
||||||
|
float missing_; // user specified missing value
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ColumnarMetaInfo {
|
||||||
|
// data type of the column
|
||||||
|
ColumnDType type{ColumnDType::kUnknown};
|
||||||
|
// location of the column in an Arrow record batch
|
||||||
|
int64_t loc{-1};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ArrowSchemaImporter {
|
||||||
|
std::vector<ColumnarMetaInfo> columns;
|
||||||
|
|
||||||
|
// map Arrow format strings to types
|
||||||
|
static ColumnDType FormatMap(char const* format_str) {
|
||||||
|
CHECK(format_str) << "Format string cannot be empty";
|
||||||
|
switch (format_str[0]) {
|
||||||
|
case 'c':
|
||||||
|
return ColumnDType::kInt8;
|
||||||
|
case 'C':
|
||||||
|
return ColumnDType::kUInt8;
|
||||||
|
case 's':
|
||||||
|
return ColumnDType::kInt16;
|
||||||
|
case 'S':
|
||||||
|
return ColumnDType::kUInt16;
|
||||||
|
case 'i':
|
||||||
|
return ColumnDType::kInt32;
|
||||||
|
case 'I':
|
||||||
|
return ColumnDType::kUInt32;
|
||||||
|
case 'l':
|
||||||
|
return ColumnDType::kInt64;
|
||||||
|
case 'L':
|
||||||
|
return ColumnDType::kUInt64;
|
||||||
|
case 'f':
|
||||||
|
return ColumnDType::kFloat;
|
||||||
|
case 'g':
|
||||||
|
return ColumnDType::kDouble;
|
||||||
|
default:
|
||||||
|
CHECK(false) << "Column data type not supported by XGBoost";
|
||||||
|
return ColumnDType::kUnknown;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Import(struct ArrowSchema *schema) {
|
||||||
|
if (schema) {
|
||||||
|
CHECK(std::string(schema->format) == "+s"); // NOLINT
|
||||||
|
CHECK(columns.empty());
|
||||||
|
for (auto i = 0; i < schema->n_children; ++i) {
|
||||||
|
std::string name{schema->children[i]->name};
|
||||||
|
ColumnDType type = FormatMap(schema->children[i]->format);
|
||||||
|
ColumnarMetaInfo col_info{type, i};
|
||||||
|
columns.push_back(col_info);
|
||||||
|
}
|
||||||
|
if (schema->release) {
|
||||||
|
schema->release(schema);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class ArrowColumnarBatch {
|
||||||
|
public:
|
||||||
|
ArrowColumnarBatch(struct ArrowArray *rb, struct ArrowSchemaImporter* schema)
|
||||||
|
: rb_{rb}, schema_{schema} {
|
||||||
|
CHECK(rb_) << "Cannot import non-existent record batch";
|
||||||
|
CHECK(!schema_->columns.empty()) << "Cannot import record batch without a schema";
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t Import(float missing) {
|
||||||
|
auto& infov = schema_->columns;
|
||||||
|
for (size_t i = 0; i < infov.size(); ++i) {
|
||||||
|
columns_.push_back(CreateColumn(i, infov[i], missing));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the starting location for every row in this batch
|
||||||
|
auto batch_size = rb_->length;
|
||||||
|
auto num_columns = columns_.size();
|
||||||
|
row_offsets_.resize(batch_size + 1, 0);
|
||||||
|
for (auto i = 0; i < batch_size; ++i) {
|
||||||
|
row_offsets_[i+1] = row_offsets_[i];
|
||||||
|
for (size_t j = 0; j < num_columns; ++j) {
|
||||||
|
if (GetColumn(j).IsValidElement(i)) {
|
||||||
|
row_offsets_[i+1]++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// return number of elements in the batch
|
||||||
|
return row_offsets_.back();
|
||||||
|
}
|
||||||
|
|
||||||
|
ArrowColumnarBatch(const ArrowColumnarBatch&) = delete;
|
||||||
|
ArrowColumnarBatch& operator=(const ArrowColumnarBatch&) = delete;
|
||||||
|
ArrowColumnarBatch(ArrowColumnarBatch&&) = delete;
|
||||||
|
ArrowColumnarBatch& operator=(ArrowColumnarBatch&&) = delete;
|
||||||
|
|
||||||
|
virtual ~ArrowColumnarBatch() {
|
||||||
|
if (rb_ && rb_->release) {
|
||||||
|
rb_->release(rb_);
|
||||||
|
rb_ = nullptr;
|
||||||
|
}
|
||||||
|
columns_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t Size() const { return rb_ ? rb_->length : 0; }
|
||||||
|
|
||||||
|
size_t NumColumns() const { return columns_.size(); }
|
||||||
|
|
||||||
|
size_t NumElements() const { return row_offsets_.back(); }
|
||||||
|
|
||||||
|
const Column& GetColumn(size_t col_idx) const {
|
||||||
|
return *columns_[col_idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
void ShiftRowOffsets(size_t batch_offset) {
|
||||||
|
std::transform(row_offsets_.begin(), row_offsets_.end(), row_offsets_.begin(),
|
||||||
|
[=](size_t c) { return c + batch_offset; });
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::vector<size_t>& RowOffsets() const { return row_offsets_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<Column> CreateColumn(size_t idx,
|
||||||
|
ColumnarMetaInfo info,
|
||||||
|
float missing) const {
|
||||||
|
if (info.loc < 0) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto loc_in_batch = info.loc;
|
||||||
|
auto length = rb_->length;
|
||||||
|
auto null_count = rb_->null_count;
|
||||||
|
auto buffers0 = rb_->children[loc_in_batch]->buffers[0];
|
||||||
|
auto buffers1 = rb_->children[loc_in_batch]->buffers[1];
|
||||||
|
const uint8_t* bitmap = buffers0 ? reinterpret_cast<const uint8_t*>(buffers0) : nullptr;
|
||||||
|
const uint8_t* data = buffers1 ? reinterpret_cast<const uint8_t*>(buffers1) : nullptr;
|
||||||
|
|
||||||
|
// if null_count is not computed, compute it here
|
||||||
|
if (null_count < 0) {
|
||||||
|
if (!bitmap) {
|
||||||
|
null_count = 0;
|
||||||
|
} else {
|
||||||
|
null_count = length;
|
||||||
|
for (auto i = 0; i < length; ++i) {
|
||||||
|
if (bitmap[i/8] & (1 << (i%8))) {
|
||||||
|
null_count--;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (info.type) {
|
||||||
|
case ColumnDType::kInt8:
|
||||||
|
return std::make_shared<PrimitiveColumn<int8_t>>(
|
||||||
|
idx, length, null_count, bitmap,
|
||||||
|
reinterpret_cast<const int8_t*>(data), missing);
|
||||||
|
case ColumnDType::kUInt8:
|
||||||
|
return std::make_shared<PrimitiveColumn<uint8_t>>(
|
||||||
|
idx, length, null_count, bitmap, data, missing);
|
||||||
|
case ColumnDType::kInt16:
|
||||||
|
return std::make_shared<PrimitiveColumn<int16_t>>(
|
||||||
|
idx, length, null_count, bitmap,
|
||||||
|
reinterpret_cast<const int16_t*>(data), missing);
|
||||||
|
case ColumnDType::kUInt16:
|
||||||
|
return std::make_shared<PrimitiveColumn<uint16_t>>(
|
||||||
|
idx, length, null_count, bitmap,
|
||||||
|
reinterpret_cast<const uint16_t*>(data), missing);
|
||||||
|
case ColumnDType::kInt32:
|
||||||
|
return std::make_shared<PrimitiveColumn<int32_t>>(
|
||||||
|
idx, length, null_count, bitmap,
|
||||||
|
reinterpret_cast<const int32_t*>(data), missing);
|
||||||
|
case ColumnDType::kUInt32:
|
||||||
|
return std::make_shared<PrimitiveColumn<uint32_t>>(
|
||||||
|
idx, length, null_count, bitmap,
|
||||||
|
reinterpret_cast<const uint32_t*>(data), missing);
|
||||||
|
case ColumnDType::kInt64:
|
||||||
|
return std::make_shared<PrimitiveColumn<int64_t>>(
|
||||||
|
idx, length, null_count, bitmap,
|
||||||
|
reinterpret_cast<const int64_t*>(data), missing);
|
||||||
|
case ColumnDType::kUInt64:
|
||||||
|
return std::make_shared<PrimitiveColumn<uint64_t>>(
|
||||||
|
idx, length, null_count, bitmap,
|
||||||
|
reinterpret_cast<const uint64_t*>(data), missing);
|
||||||
|
case ColumnDType::kFloat:
|
||||||
|
return std::make_shared<PrimitiveColumn<float>>(
|
||||||
|
idx, length, null_count, bitmap,
|
||||||
|
reinterpret_cast<const float*>(data), missing);
|
||||||
|
case ColumnDType::kDouble:
|
||||||
|
return std::make_shared<PrimitiveColumn<double>>(
|
||||||
|
idx, length, null_count, bitmap,
|
||||||
|
reinterpret_cast<const double*>(data), missing);
|
||||||
|
default:
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ArrowArray* rb_;
|
||||||
|
struct ArrowSchemaImporter* schema_;
|
||||||
|
std::vector<std::shared_ptr<Column>> columns_;
|
||||||
|
std::vector<size_t> row_offsets_;
|
||||||
|
};
|
||||||
|
|
||||||
|
using ArrowColumnarBatchVec = std::vector<std::unique_ptr<ArrowColumnarBatch>>;
|
||||||
|
// Pulls Arrow record batches from a user callback, up to `nthread` batches per call to
// Next(). The callback delivers each batch by calling SetData() on this adapter.
class RecordBatchesIterAdapter: public dmlc::DataIter<ArrowColumnarBatchVec> {
 public:
  RecordBatchesIterAdapter(XGDMatrixCallbackNext *next_callback, int nthread)
      : next_callback_{next_callback}, nbatches_{nthread} {}

  // The adapter is single-pass: this is only legal before any batch has been fetched.
  void BeforeFirst() override {
    CHECK(at_first_) << "Cannot reset RecordBatchesIterAdapter";
  }

  // Drop the previous batches and fetch new ones until either the per-call limit is
  // reached or the callback reports that nothing is left (returns 0).
  bool Next() override {
    batches_.clear();
    const auto limit = static_cast<size_t>(nbatches_);
    for (;;) {
      if (batches_.size() >= limit) {
        break;
      }
      if ((*next_callback_)(this) == 0) {
        break;
      }
      at_first_ = false;
    }
    return !batches_.empty();
  }

  // Receive one exported record batch together with its schema.
  void SetData(struct ArrowArray* rb, struct ArrowSchema* schema) {
    // The schema is imported only once, at the very beginning, no matter how many
    // batches are coming. When it is not imported, its C data exported from Arrow
    // still has to be released.
    const bool import_schema = at_first_ && schema;
    if (import_schema) {
      schema_.Import(schema);
    } else if (schema && schema->release) {
      schema->release(schema);
    }

    if (rb) {
      batches_.emplace_back(std::make_unique<ArrowColumnarBatch>(rb, &schema_));
    }
  }

  const ArrowColumnarBatchVec& Value() const override {
    return batches_;
  }

  size_t NumColumns() const { return schema_.columns.size(); }
  size_t NumRows() const { return kAdapterUnknownSize; }

 private:
  XGDMatrixCallbackNext *next_callback_;
  bool at_first_{true};
  int nbatches_;
  struct ArrowSchemaImporter schema_;
  ArrowColumnarBatchVec batches_;
};
|
||||||
}; // namespace data
|
}; // namespace data
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
#endif // XGBOOST_DATA_ADAPTER_H_
|
#endif // XGBOOST_DATA_ADAPTER_H_
|
||||||
|
|||||||
66
src/data/arrow-cdi.h
Normal file
66
src/data/arrow-cdi.h
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
/* Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
* or more contributor license agreements. See the NOTICE file
|
||||||
|
* distributed with this work for additional information
|
||||||
|
* regarding copyright ownership. The ASF licenses this file
|
||||||
|
* to you under the Apache License, Version 2.0 (the
|
||||||
|
* "License"); you may not use this file except in compliance
|
||||||
|
* with the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing,
|
||||||
|
* software distributed under the License is distributed on an
|
||||||
|
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
* KIND, either express or implied. See the License for the
|
||||||
|
* specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#define ARROW_FLAG_DICTIONARY_ORDERED 1
|
||||||
|
#define ARROW_FLAG_NULLABLE 2
|
||||||
|
#define ARROW_FLAG_MAP_KEYS_SORTED 4
|
||||||
|
|
||||||
|
struct ArrowSchema {
|
||||||
|
// Array type description
|
||||||
|
const char* format;
|
||||||
|
const char* name;
|
||||||
|
const char* metadata;
|
||||||
|
int64_t flags;
|
||||||
|
int64_t n_children;
|
||||||
|
struct ArrowSchema** children;
|
||||||
|
struct ArrowSchema* dictionary;
|
||||||
|
|
||||||
|
// Release callback
|
||||||
|
void (*release)(struct ArrowSchema*);
|
||||||
|
// Opaque producer-specific data
|
||||||
|
void* private_data;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ArrowArray {
|
||||||
|
// Array data description
|
||||||
|
int64_t length;
|
||||||
|
int64_t null_count;
|
||||||
|
int64_t offset;
|
||||||
|
int64_t n_buffers;
|
||||||
|
int64_t n_children;
|
||||||
|
const void** buffers;
|
||||||
|
struct ArrowArray** children;
|
||||||
|
struct ArrowArray* dictionary;
|
||||||
|
|
||||||
|
// Release callback
|
||||||
|
void (*release)(struct ArrowArray*);
|
||||||
|
// Opaque producer-specific data
|
||||||
|
void* private_data;
|
||||||
|
};
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
@ -1000,6 +1000,8 @@ template DMatrix *
|
|||||||
DMatrix::Create(data::IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext,
|
DMatrix::Create(data::IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext,
|
||||||
XGBoostBatchCSR> *adapter,
|
XGBoostBatchCSR> *adapter,
|
||||||
float missing, int nthread, const std::string &cache_prefix);
|
float missing, int nthread, const std::string &cache_prefix);
|
||||||
|
template DMatrix* DMatrix::Create<data::RecordBatchesIterAdapter>(
|
||||||
|
data::RecordBatchesIterAdapter* adapter, float missing, int nthread, const std::string&);
|
||||||
|
|
||||||
SparsePage SparsePage::GetTranspose(int num_columns, int32_t n_threads) const {
|
SparsePage SparsePage::GetTranspose(int num_columns, int32_t n_threads) const {
|
||||||
SparsePage transpose;
|
SparsePage transpose;
|
||||||
|
|||||||
@ -249,5 +249,70 @@ template SimpleDMatrix::SimpleDMatrix(
|
|||||||
IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR>
|
IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR>
|
||||||
*adapter,
|
*adapter,
|
||||||
float missing, int nthread);
|
float missing, int nthread);
|
||||||
|
|
||||||
|
template <>
|
||||||
|
SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, float missing, int nthread) {
|
||||||
|
auto& offset_vec = sparse_page_->offset.HostVector();
|
||||||
|
auto& data_vec = sparse_page_->data.HostVector();
|
||||||
|
uint64_t total_batch_size = 0;
|
||||||
|
uint64_t total_elements = 0;
|
||||||
|
|
||||||
|
adapter->BeforeFirst();
|
||||||
|
// Iterate over batches of input data
|
||||||
|
while (adapter->Next()) {
|
||||||
|
auto& batches = adapter->Value();
|
||||||
|
size_t num_elements = 0;
|
||||||
|
size_t num_rows = 0;
|
||||||
|
// Import Arrow RecordBatches
|
||||||
|
#pragma omp parallel for reduction(+ : num_elements, num_rows) num_threads(nthread)
|
||||||
|
for (int i = 0; i < static_cast<int>(batches.size()); ++i) { // NOLINT
|
||||||
|
num_elements += batches[i]->Import(missing);
|
||||||
|
num_rows += batches[i]->Size();
|
||||||
|
}
|
||||||
|
total_elements += num_elements;
|
||||||
|
total_batch_size += num_rows;
|
||||||
|
// Compute global offset for every row and starting row for every batch
|
||||||
|
std::vector<uint64_t> batch_offsets(batches.size());
|
||||||
|
for (size_t i = 0; i < batches.size(); ++i) {
|
||||||
|
if (i == 0) {
|
||||||
|
batch_offsets[i] = total_batch_size - num_rows;
|
||||||
|
batches[i]->ShiftRowOffsets(total_elements - num_elements);
|
||||||
|
} else {
|
||||||
|
batch_offsets[i] = batch_offsets[i - 1] + batches[i - 1]->Size();
|
||||||
|
batches[i]->ShiftRowOffsets(batches[i - 1]->RowOffsets().back());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Pre-allocate DMatrix memory
|
||||||
|
data_vec.resize(total_elements);
|
||||||
|
offset_vec.resize(total_batch_size + 1);
|
||||||
|
// Copy data into DMatrix
|
||||||
|
#pragma omp parallel num_threads(nthread)
|
||||||
|
{
|
||||||
|
#pragma omp for nowait
|
||||||
|
for (int i = 0; i < static_cast<int>(batches.size()); ++i) { // NOLINT
|
||||||
|
size_t begin = batches[i]->RowOffsets()[0];
|
||||||
|
for (size_t k = 0; k < batches[i]->Size(); ++k) {
|
||||||
|
for (size_t j = 0; j < batches[i]->NumColumns(); ++j) {
|
||||||
|
auto element = batches[i]->GetColumn(j).GetElement(k);
|
||||||
|
if (!std::isnan(element.value)) {
|
||||||
|
data_vec[begin++] = Entry(element.column_idx, element.value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#pragma omp for nowait
|
||||||
|
for (int i = 0; i < static_cast<int>(batches.size()); ++i) {
|
||||||
|
auto& offsets = batches[i]->RowOffsets();
|
||||||
|
std::copy(offsets.begin() + 1, offsets.end(), offset_vec.begin() + batch_offsets[i] + 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Synchronise worker columns
|
||||||
|
info_.num_col_ = adapter->NumColumns();
|
||||||
|
rabit::Allreduce<rabit::op::Max>(&info_.num_col_, 1);
|
||||||
|
info_.num_row_ = total_batch_size;
|
||||||
|
info_.num_nonzero_ = data_vec.size();
|
||||||
|
CHECK_EQ(offset_vec.back(), info_.num_nonzero_);
|
||||||
|
}
|
||||||
} // namespace data
|
} // namespace data
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -26,6 +26,8 @@ dependencies:
|
|||||||
- awscli
|
- awscli
|
||||||
- numba
|
- numba
|
||||||
- llvmlite
|
- llvmlite
|
||||||
|
- cffi
|
||||||
|
- pyarrow
|
||||||
- pip:
|
- pip:
|
||||||
- shap
|
- shap
|
||||||
- awscli
|
- awscli
|
||||||
|
|||||||
@ -33,6 +33,8 @@ dependencies:
|
|||||||
- numba
|
- numba
|
||||||
- llvmlite
|
- llvmlite
|
||||||
- py-ubjson
|
- py-ubjson
|
||||||
|
- cffi
|
||||||
|
- pyarrow
|
||||||
- pip:
|
- pip:
|
||||||
- shap
|
- shap
|
||||||
- ipython # required by shap at import time.
|
- ipython # required by shap at import time.
|
||||||
|
|||||||
@ -33,6 +33,8 @@ dependencies:
|
|||||||
- boto3
|
- boto3
|
||||||
- awscli
|
- awscli
|
||||||
- py-ubjson
|
- py-ubjson
|
||||||
|
- cffi
|
||||||
|
- pyarrow
|
||||||
- pip:
|
- pip:
|
||||||
- sphinx_rtd_theme
|
- sphinx_rtd_theme
|
||||||
- datatable
|
- datatable
|
||||||
|
|||||||
@ -15,7 +15,8 @@ dependencies:
|
|||||||
- pytest
|
- pytest
|
||||||
- jsonschema
|
- jsonschema
|
||||||
- hypothesis
|
- hypothesis
|
||||||
- jsonschema
|
|
||||||
- python-graphviz
|
- python-graphviz
|
||||||
- pip
|
- pip
|
||||||
- py-ubjson
|
- py-ubjson
|
||||||
|
- cffi
|
||||||
|
- pyarrow
|
||||||
|
|||||||
@ -17,3 +17,5 @@ dependencies:
|
|||||||
- modin-ray
|
- modin-ray
|
||||||
- pip
|
- pip
|
||||||
- py-ubjson
|
- py-ubjson
|
||||||
|
- cffi
|
||||||
|
- pyarrow
|
||||||
|
|||||||
88
tests/python/test_with_arrow.py
Normal file
88
tests/python/test_with_arrow.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
import unittest
|
||||||
|
import pytest
|
||||||
|
import numpy as np
|
||||||
|
import testing as tm
|
||||||
|
import xgboost as xgb
|
||||||
|
import os
|
||||||
|
|
||||||
|
try:
|
||||||
|
import pyarrow as pa
|
||||||
|
import pyarrow.csv as pc
|
||||||
|
import pandas as pd
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Skip every test in this module when either pyarrow or pandas is reported
# as unavailable by the shared testing helpers.
_arrow_skip = tm.no_arrow()
_pandas_skip = tm.no_pandas()
pytestmark = pytest.mark.skipif(
    _arrow_skip["condition"] or _pandas_skip["condition"],
    reason=_arrow_skip["reason"] + " or " + _pandas_skip["reason"],
)

dpath = "demo/data/"
|
||||||
|
|
||||||
|
|
||||||
|
class TestArrowTable(unittest.TestCase):
    """Tests that build ``xgb.DMatrix`` objects from ``pyarrow.Table`` input."""

    def test_arrow_table(self):
        """A table made from a 2x4 DataFrame produces a 2x4 DMatrix."""
        df = pd.DataFrame(
            [[0, 1, 2.0, 3.0], [1, 2, 3.0, 4.0]], columns=["a", "b", "c", "d"]
        )
        table = pa.Table.from_pandas(df)
        dm = xgb.DMatrix(table)
        assert dm.num_row() == 2
        assert dm.num_col() == 4

    def test_arrow_table_with_label(self):
        """A label set after construction is returned unchanged by get_label."""
        df = pd.DataFrame([[1, 2.0, 3.0], [2, 3.0, 4.0]], columns=["a", "b", "c"])
        table = pa.Table.from_pandas(df)
        label = np.array([0, 1])
        dm = xgb.DMatrix(table)
        dm.set_label(label)
        assert dm.num_row() == 2
        assert dm.num_col() == 3
        # Compare against the very array that was set rather than a second
        # literal, so the expectation cannot drift from the input.
        np.testing.assert_array_equal(dm.get_label(), label)

    def test_arrow_table_from_np(self):
        """A table assembled from three 4-element arrays has 4 rows, 3 columns."""
        coldata = np.array(
            [[1.0, 1.0, 0.0, 0.0], [2.0, 0.0, 1.0, 0.0], [3.0, 0.0, 0.0, 1.0]]
        )
        # Each row of ``coldata`` becomes one Arrow column.
        cols = list(map(pa.array, coldata))
        table = pa.Table.from_arrays(cols, ["a", "b", "c"])
        dm = xgb.DMatrix(table)
        assert dm.num_row() == 4
        assert dm.num_col() == 3

    def test_arrow_train(self):
        """Training on an Arrow table matches training on the source DataFrame."""
        # Seeded generator: a failure can be reproduced with the same data.
        rng = np.random.RandomState(1994)
        rows = 100
        X = pd.DataFrame(
            {
                "A": rng.randint(0, 10, size=rows),
                "B": rng.randn(rows),
                "C": rng.permutation([1, 0] * (rows // 2)),
            }
        )
        y = pd.Series(rng.randn(rows))
        table = pa.Table.from_pandas(X)
        dtrain1 = xgb.DMatrix(table)
        dtrain1.set_label(y)
        bst1 = xgb.train({}, dtrain1, num_boost_round=10)
        preds1 = bst1.predict(xgb.DMatrix(X))
        dtrain2 = xgb.DMatrix(X, y)
        bst2 = xgb.train({}, dtrain2, num_boost_round=10)
        preds2 = bst2.predict(xgb.DMatrix(X))
        np.testing.assert_allclose(preds1, preds2)

    def test_arrow_survival(self):
        """Arrow columns passed as survival bounds round-trip through DMatrix."""
        data = os.path.join(tm.PROJECT_ROOT, "demo", "data", "veterans_lung_cancer.csv")
        table = pc.read_csv(data)
        y_lower_bound = table["Survival_label_lower_bound"]
        y_upper_bound = table["Survival_label_upper_bound"]
        X = table.drop(["Survival_label_lower_bound", "Survival_label_upper_bound"])

        dtrain = xgb.DMatrix(
            X, label_lower_bound=y_lower_bound, label_upper_bound=y_upper_bound
        )
        y_np_up = dtrain.get_float_info("label_upper_bound")
        y_np_low = dtrain.get_float_info("label_lower_bound")
        np.testing.assert_equal(y_np_up, y_upper_bound.to_pandas().values)
        np.testing.assert_equal(y_np_low, y_lower_bound.to_pandas().values)
|
||||||
@ -53,6 +53,15 @@ def no_pandas():
|
|||||||
'reason': 'Pandas is not installed.'}
|
'reason': 'Pandas is not installed.'}
|
||||||
|
|
||||||
|
|
||||||
|
def no_arrow():
    """Return a skip descriptor for tests that need pyarrow.

    The result is a dict with two keys: ``"condition"`` is ``True`` when
    importing ``pyarrow`` raises ``ImportError`` and ``False`` otherwise;
    ``"reason"`` is always the fixed explanatory message.
    """
    message = "pyarrow is not installed"
    try:
        import pyarrow  # noqa
    except ImportError:
        unavailable = True
    else:
        unavailable = False
    return {"condition": unavailable, "reason": message}
|
||||||
|
|
||||||
|
|
||||||
def no_modin():
|
def no_modin():
|
||||||
reason = 'Modin is not installed.'
|
reason = 'Modin is not installed.'
|
||||||
try:
|
try:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user