Support arrow through pandas ext types. (#9612)
- Use pandas extension type for pyarrow support. - Additional support for QDM. - Additional support for inplace_predict.
This commit is contained in:
parent
3f2093fb81
commit
60526100e3
@ -172,9 +172,8 @@ Support Matrix
|
||||
+-------------------------+-----------+-------------------+-----------+-----------+--------------------+-------------+
|
||||
| modin.Series | NPA | FF | NPA | NPA | FF | |
|
||||
+-------------------------+-----------+-------------------+-----------+-----------+--------------------+-------------+
|
||||
| pyarrow.Table | T | F | | NPA | FF | |
|
||||
| pyarrow.Table | NPA | NPA | NPA | NPA | NPA | NPA |
|
||||
+-------------------------+-----------+-------------------+-----------+-----------+--------------------+-------------+
|
||||
| pyarrow.dataset.Dataset | T | F | | | F | |
|
||||
+-------------------------+-----------+-------------------+-----------+-----------+--------------------+-------------+
|
||||
| _\_array\_\_ | NPA | F | NPA | NPA | H | |
|
||||
+-------------------------+-----------+-------------------+-----------+-----------+--------------------+-------------+
|
||||
|
||||
@ -552,24 +552,6 @@ XGB_DLL int XGProxyDMatrixSetDataCSR(DMatrixHandle handle, char const *indptr,
|
||||
|
||||
/** @} */ // End of Streaming
|
||||
|
||||
XGB_DLL int XGImportArrowRecordBatch(DataIterHandle data_handle, void *ptr_array, void *ptr_schema);
|
||||
|
||||
/*!
|
||||
* \brief Construct DMatrix from arrow using callbacks. Arrow related C API is not stable
|
||||
* and subject to change in the future.
|
||||
*
|
||||
* \param next Callback function for fetching arrow records.
|
||||
* \param config JSON encoded configuration. Required values are:
|
||||
* - missing: Which value to represent missing value.
|
||||
* - nbatch: Number of batches in arrow table.
|
||||
* - nthread (optional): Number of threads used for initializing DMatrix.
|
||||
* \param out The created DMatrix.
|
||||
*
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGDMatrixCreateFromArrowCallback(XGDMatrixCallbackNext *next, char const *config,
|
||||
DMatrixHandle *out);
|
||||
|
||||
/*!
|
||||
* \brief create a new dmatrix from sliced content of existing matrix
|
||||
* \param handle instance of data matrix to be sliced
|
||||
|
||||
@ -2431,6 +2431,8 @@ class Booster:
|
||||
|
||||
from .data import (
|
||||
_array_interface,
|
||||
_arrow_transform,
|
||||
_is_arrow,
|
||||
_is_cudf_df,
|
||||
_is_cupy_array,
|
||||
_is_list,
|
||||
@ -2442,6 +2444,8 @@ class Booster:
|
||||
)
|
||||
|
||||
enable_categorical = True
|
||||
if _is_arrow(data):
|
||||
data = _arrow_transform(data)
|
||||
if _is_pandas_series(data):
|
||||
import pandas as pd
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ import ctypes
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from typing import Any, Callable, Iterator, List, Optional, Sequence, Tuple, cast
|
||||
from typing import Any, Callable, List, Optional, Sequence, Tuple, cast
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -676,86 +676,51 @@ def _from_dt_df(
|
||||
|
||||
|
||||
def _is_arrow(data: DataType) -> bool:
|
||||
try:
|
||||
import pyarrow as pa
|
||||
from pyarrow import dataset as arrow_dataset
|
||||
|
||||
return isinstance(data, (pa.Table, arrow_dataset.Dataset))
|
||||
except ImportError:
|
||||
return False
|
||||
return lazy_isinstance(data, "pyarrow.lib", "Table") or lazy_isinstance(
|
||||
data, "pyarrow._dataset", "Dataset"
|
||||
)
|
||||
|
||||
|
||||
def record_batch_data_iter(data_iter: Iterator) -> Callable:
|
||||
"""Data iterator used to ingest Arrow columnar record batches. We are not using
|
||||
class DataIter because it is only intended for building Device DMatrix and external
|
||||
memory DMatrix.
|
||||
|
||||
"""
|
||||
from pyarrow.cffi import ffi
|
||||
|
||||
c_schemas: List[ffi.CData] = []
|
||||
c_arrays: List[ffi.CData] = []
|
||||
|
||||
def _next(data_handle: int) -> int:
|
||||
from pyarrow.cffi import ffi
|
||||
|
||||
try:
|
||||
batch = next(data_iter)
|
||||
c_schemas.append(ffi.new("struct ArrowSchema*"))
|
||||
c_arrays.append(ffi.new("struct ArrowArray*"))
|
||||
ptr_schema = int(ffi.cast("uintptr_t", c_schemas[-1]))
|
||||
ptr_array = int(ffi.cast("uintptr_t", c_arrays[-1]))
|
||||
# pylint: disable=protected-access
|
||||
batch._export_to_c(ptr_array, ptr_schema)
|
||||
_check_call(
|
||||
_LIB.XGImportArrowRecordBatch(
|
||||
ctypes.c_void_p(data_handle),
|
||||
ctypes.c_void_p(ptr_array),
|
||||
ctypes.c_void_p(ptr_schema),
|
||||
)
|
||||
)
|
||||
return 1
|
||||
except StopIteration:
|
||||
return 0
|
||||
|
||||
return _next
|
||||
|
||||
|
||||
def _from_arrow(
|
||||
data: DataType,
|
||||
missing: FloatCompatible,
|
||||
nthread: int,
|
||||
feature_names: Optional[FeatureNames],
|
||||
feature_types: Optional[FeatureTypes],
|
||||
enable_categorical: bool,
|
||||
) -> DispatchedDataBackendReturnType:
|
||||
def _arrow_transform(data: DataType) -> Any:
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
from pyarrow.dataset import Dataset
|
||||
|
||||
if not all(
|
||||
pa.types.is_integer(t) or pa.types.is_floating(t) for t in data.schema.types
|
||||
):
|
||||
raise ValueError(
|
||||
"Features in dataset can only be integers or floating point number"
|
||||
)
|
||||
if enable_categorical:
|
||||
raise ValueError("categorical data in arrow is not supported yet.")
|
||||
if isinstance(data, Dataset):
|
||||
raise TypeError("arrow Dataset is not supported.")
|
||||
|
||||
batches = data.to_batches()
|
||||
rb_iter = iter(batches)
|
||||
it = record_batch_data_iter(rb_iter)
|
||||
next_callback = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p)(it)
|
||||
handle = ctypes.c_void_p()
|
||||
config = from_pystr_to_cstr(
|
||||
json.dumps({"missing": missing, "nthread": nthread, "nbatch": len(batches)})
|
||||
)
|
||||
_check_call(
|
||||
_LIB.XGDMatrixCreateFromArrowCallback(
|
||||
next_callback,
|
||||
config,
|
||||
ctypes.byref(handle),
|
||||
)
|
||||
)
|
||||
return handle, feature_names, feature_types
|
||||
data = cast(pa.Table, data)
|
||||
|
||||
def type_mapper(dtype: pa.DataType) -> Optional[str]:
|
||||
"""Maps pyarrow type to pandas arrow extension type."""
|
||||
if pa.types.is_int8(dtype):
|
||||
return pd.ArrowDtype(pa.int8())
|
||||
if pa.types.is_int16(dtype):
|
||||
return pd.ArrowDtype(pa.int16())
|
||||
if pa.types.is_int32(dtype):
|
||||
return pd.ArrowDtype(pa.int32())
|
||||
if pa.types.is_int64(dtype):
|
||||
return pd.ArrowDtype(pa.int64())
|
||||
if pa.types.is_uint8(dtype):
|
||||
return pd.ArrowDtype(pa.uint8())
|
||||
if pa.types.is_uint16(dtype):
|
||||
return pd.ArrowDtype(pa.uint16())
|
||||
if pa.types.is_uint32(dtype):
|
||||
return pd.ArrowDtype(pa.uint32())
|
||||
if pa.types.is_uint64(dtype):
|
||||
return pd.ArrowDtype(pa.uint64())
|
||||
if pa.types.is_float16(dtype):
|
||||
return pd.ArrowDtype(pa.float16())
|
||||
if pa.types.is_float32(dtype):
|
||||
return pd.ArrowDtype(pa.float32())
|
||||
if pa.types.is_float64(dtype):
|
||||
return pd.ArrowDtype(pa.float64())
|
||||
if pa.types.is_boolean(dtype):
|
||||
return pd.ArrowDtype(pa.bool_())
|
||||
return None
|
||||
|
||||
df = data.to_pandas(types_mapper=type_mapper)
|
||||
return df
|
||||
|
||||
|
||||
def _is_cudf_df(data: DataType) -> bool:
|
||||
@ -1081,6 +1046,8 @@ def dispatch_data_backend(
|
||||
return _from_list(data, missing, threads, feature_names, feature_types)
|
||||
if _is_tuple(data):
|
||||
return _from_tuple(data, missing, threads, feature_names, feature_types)
|
||||
if _is_arrow(data):
|
||||
data = _arrow_transform(data)
|
||||
if _is_pandas_series(data):
|
||||
import pandas as pd
|
||||
|
||||
@ -1114,10 +1081,6 @@ def dispatch_data_backend(
|
||||
return _from_pandas_series(
|
||||
data, missing, threads, enable_categorical, feature_names, feature_types
|
||||
)
|
||||
if _is_arrow(data):
|
||||
return _from_arrow(
|
||||
data, missing, threads, feature_names, feature_types, enable_categorical
|
||||
)
|
||||
if _has_array_protocol(data):
|
||||
array = np.asarray(data)
|
||||
return _from_numpy_array(array, missing, threads, feature_names, feature_types)
|
||||
@ -1217,6 +1180,8 @@ def dispatch_meta_backend(
|
||||
if _is_np_array_like(data):
|
||||
_meta_from_numpy(data, name, dtype, handle)
|
||||
return
|
||||
if _is_arrow(data):
|
||||
data = _arrow_transform(data)
|
||||
if _is_pandas_df(data):
|
||||
data, _, _ = _transform_pandas_df(data, False, meta=name, meta_type=dtype)
|
||||
_meta_from_numpy(data, name, dtype, handle)
|
||||
@ -1311,6 +1276,8 @@ def _proxy_transform(
|
||||
import pandas as pd
|
||||
|
||||
data = pd.DataFrame(data)
|
||||
if _is_arrow(data):
|
||||
data = _arrow_transform(data)
|
||||
if _is_pandas_df(data):
|
||||
arr, feature_names, feature_types = _transform_pandas_df(
|
||||
data, enable_categorical, feature_names, feature_types
|
||||
|
||||
@ -532,33 +532,8 @@ XGB_DLL int XGDMatrixCreateFromDT(void** data, const char** feature_stypes,
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGImportArrowRecordBatch(DataIterHandle data_handle, void *ptr_array,
|
||||
void *ptr_schema) {
|
||||
API_BEGIN();
|
||||
static_cast<data::RecordBatchesIterAdapter *>(data_handle)
|
||||
->SetData(static_cast<struct ArrowArray *>(ptr_array),
|
||||
static_cast<struct ArrowSchema *>(ptr_schema));
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixCreateFromArrowCallback(XGDMatrixCallbackNext *next, char const *config,
|
||||
DMatrixHandle *out) {
|
||||
API_BEGIN();
|
||||
xgboost_CHECK_C_ARG_PTR(config);
|
||||
auto jconfig = Json::Load(StringView{config});
|
||||
auto missing = GetMissing(jconfig);
|
||||
auto n_batches = RequiredArg<Integer>(jconfig, "nbatch", __func__);
|
||||
auto n_threads = OptionalArg<Integer, std::int64_t>(jconfig, "nthread", 0);
|
||||
data::RecordBatchesIterAdapter adapter(next, n_batches);
|
||||
xgboost_CHECK_C_ARG_PTR(out);
|
||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
|
||||
const int* idxset,
|
||||
xgboost::bst_ulong len,
|
||||
DMatrixHandle* out) {
|
||||
XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle, const int *idxset, xgboost::bst_ulong len,
|
||||
DMatrixHandle *out) {
|
||||
xgboost_CHECK_C_ARG_PTR(out);
|
||||
return XGDMatrixSliceDMatrixEx(handle, idxset, len, out, 0);
|
||||
}
|
||||
|
||||
@ -20,7 +20,6 @@
|
||||
#include "../common/error_msg.h" // for MaxFeatureSize
|
||||
#include "../common/math.h"
|
||||
#include "array_interface.h"
|
||||
#include "arrow-cdi.h"
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/logging.h"
|
||||
@ -899,306 +898,6 @@ class Column {
|
||||
const uint8_t* bitmap_;
|
||||
};
|
||||
|
||||
// Only columns of primitive types are supported. An ArrowColumnarBatch is a
|
||||
// collection of std::shared_ptr<PrimitiveColumn>. These columns can be of different data types.
|
||||
// Hence, PrimitiveColumn is a class template; and all concrete PrimitiveColumns
|
||||
// derive from the abstract class Column.
|
||||
template <typename T>
|
||||
class PrimitiveColumn : public Column {
|
||||
static constexpr float kNaN = std::numeric_limits<float>::quiet_NaN();
|
||||
|
||||
public:
|
||||
PrimitiveColumn(size_t idx, size_t length, size_t null_count,
|
||||
const uint8_t* bitmap, const T* data, float missing)
|
||||
: Column{idx, length, null_count, bitmap}, data_{data}, missing_{missing} {}
|
||||
|
||||
COOTuple GetElement(size_t row_idx) const override {
|
||||
CHECK(data_ && row_idx < length_) << "Column is empty or out-of-bound index of the column";
|
||||
return { row_idx, col_idx_, IsValidElement(row_idx) ?
|
||||
static_cast<float>(data_[row_idx]) : kNaN };
|
||||
}
|
||||
|
||||
bool IsValidElement(size_t row_idx) const override {
|
||||
// std::isfinite needs to cast to double to prevent msvc report error
|
||||
return IsValid(row_idx)
|
||||
&& std::isfinite(static_cast<double>(data_[row_idx]))
|
||||
&& static_cast<float>(data_[row_idx]) != missing_;
|
||||
}
|
||||
|
||||
std::vector<float> AsFloatVector() const override {
|
||||
CHECK(data_) << "Column is empty";
|
||||
std::vector<float> fv(length_);
|
||||
std::transform(data_, data_ + length_, fv.begin(),
|
||||
[](T v) { return static_cast<float>(v); });
|
||||
return fv;
|
||||
}
|
||||
|
||||
std::vector<uint64_t> AsUint64Vector() const override {
|
||||
CHECK(data_) << "Column is empty";
|
||||
std::vector<uint64_t> iv(length_);
|
||||
std::transform(data_, data_ + length_, iv.begin(),
|
||||
[](T v) { return static_cast<uint64_t>(v); });
|
||||
return iv;
|
||||
}
|
||||
|
||||
private:
|
||||
const T* data_;
|
||||
float missing_; // user specified missing value
|
||||
};
|
||||
|
||||
struct ColumnarMetaInfo {
|
||||
// data type of the column
|
||||
ColumnDType type{ColumnDType::kUnknown};
|
||||
// location of the column in an Arrow record batch
|
||||
int64_t loc{-1};
|
||||
};
|
||||
|
||||
struct ArrowSchemaImporter {
|
||||
std::vector<ColumnarMetaInfo> columns;
|
||||
|
||||
// map Arrow format strings to types
|
||||
static ColumnDType FormatMap(char const* format_str) {
|
||||
CHECK(format_str) << "Format string cannot be empty";
|
||||
switch (format_str[0]) {
|
||||
case 'c':
|
||||
return ColumnDType::kInt8;
|
||||
case 'C':
|
||||
return ColumnDType::kUInt8;
|
||||
case 's':
|
||||
return ColumnDType::kInt16;
|
||||
case 'S':
|
||||
return ColumnDType::kUInt16;
|
||||
case 'i':
|
||||
return ColumnDType::kInt32;
|
||||
case 'I':
|
||||
return ColumnDType::kUInt32;
|
||||
case 'l':
|
||||
return ColumnDType::kInt64;
|
||||
case 'L':
|
||||
return ColumnDType::kUInt64;
|
||||
case 'f':
|
||||
return ColumnDType::kFloat;
|
||||
case 'g':
|
||||
return ColumnDType::kDouble;
|
||||
default:
|
||||
CHECK(false) << "Column data type not supported by XGBoost";
|
||||
return ColumnDType::kUnknown;
|
||||
}
|
||||
}
|
||||
|
||||
void Import(struct ArrowSchema *schema) {
|
||||
if (schema) {
|
||||
CHECK(std::string(schema->format) == "+s"); // NOLINT
|
||||
CHECK(columns.empty());
|
||||
for (auto i = 0; i < schema->n_children; ++i) {
|
||||
std::string name{schema->children[i]->name};
|
||||
ColumnDType type = FormatMap(schema->children[i]->format);
|
||||
ColumnarMetaInfo col_info{type, i};
|
||||
columns.push_back(col_info);
|
||||
}
|
||||
if (schema->release) {
|
||||
schema->release(schema);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class ArrowColumnarBatch {
|
||||
public:
|
||||
ArrowColumnarBatch(struct ArrowArray *rb, struct ArrowSchemaImporter* schema)
|
||||
: rb_{rb}, schema_{schema} {
|
||||
CHECK(rb_) << "Cannot import non-existent record batch";
|
||||
CHECK(!schema_->columns.empty()) << "Cannot import record batch without a schema";
|
||||
}
|
||||
|
||||
size_t Import(float missing) {
|
||||
auto& infov = schema_->columns;
|
||||
for (size_t i = 0; i < infov.size(); ++i) {
|
||||
columns_.push_back(CreateColumn(i, infov[i], missing));
|
||||
}
|
||||
|
||||
// Compute the starting location for every row in this batch
|
||||
auto batch_size = rb_->length;
|
||||
auto num_columns = columns_.size();
|
||||
row_offsets_.resize(batch_size + 1, 0);
|
||||
for (auto i = 0; i < batch_size; ++i) {
|
||||
row_offsets_[i+1] = row_offsets_[i];
|
||||
for (size_t j = 0; j < num_columns; ++j) {
|
||||
if (GetColumn(j).IsValidElement(i)) {
|
||||
row_offsets_[i+1]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
// return number of elements in the batch
|
||||
return row_offsets_.back();
|
||||
}
|
||||
|
||||
ArrowColumnarBatch(const ArrowColumnarBatch&) = delete;
|
||||
ArrowColumnarBatch& operator=(const ArrowColumnarBatch&) = delete;
|
||||
ArrowColumnarBatch(ArrowColumnarBatch&&) = delete;
|
||||
ArrowColumnarBatch& operator=(ArrowColumnarBatch&&) = delete;
|
||||
|
||||
virtual ~ArrowColumnarBatch() {
|
||||
if (rb_ && rb_->release) {
|
||||
rb_->release(rb_);
|
||||
rb_ = nullptr;
|
||||
}
|
||||
columns_.clear();
|
||||
}
|
||||
|
||||
size_t Size() const { return rb_ ? rb_->length : 0; }
|
||||
|
||||
size_t NumColumns() const { return columns_.size(); }
|
||||
|
||||
size_t NumElements() const { return row_offsets_.back(); }
|
||||
|
||||
const Column& GetColumn(size_t col_idx) const {
|
||||
return *columns_[col_idx];
|
||||
}
|
||||
|
||||
void ShiftRowOffsets(size_t batch_offset) {
|
||||
std::transform(row_offsets_.begin(), row_offsets_.end(), row_offsets_.begin(),
|
||||
[=](size_t c) { return c + batch_offset; });
|
||||
}
|
||||
|
||||
const std::vector<size_t>& RowOffsets() const { return row_offsets_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<Column> CreateColumn(size_t idx,
|
||||
ColumnarMetaInfo info,
|
||||
float missing) const {
|
||||
if (info.loc < 0) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto loc_in_batch = info.loc;
|
||||
auto length = rb_->length;
|
||||
auto null_count = rb_->null_count;
|
||||
auto buffers0 = rb_->children[loc_in_batch]->buffers[0];
|
||||
auto buffers1 = rb_->children[loc_in_batch]->buffers[1];
|
||||
const uint8_t* bitmap = buffers0 ? reinterpret_cast<const uint8_t*>(buffers0) : nullptr;
|
||||
const uint8_t* data = buffers1 ? reinterpret_cast<const uint8_t*>(buffers1) : nullptr;
|
||||
|
||||
// if null_count is not computed, compute it here
|
||||
if (null_count < 0) {
|
||||
if (!bitmap) {
|
||||
null_count = 0;
|
||||
} else {
|
||||
null_count = length;
|
||||
for (auto i = 0; i < length; ++i) {
|
||||
if (bitmap[i/8] & (1 << (i%8))) {
|
||||
null_count--;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch (info.type) {
|
||||
case ColumnDType::kInt8:
|
||||
return std::make_shared<PrimitiveColumn<int8_t>>(
|
||||
idx, length, null_count, bitmap,
|
||||
reinterpret_cast<const int8_t*>(data), missing);
|
||||
case ColumnDType::kUInt8:
|
||||
return std::make_shared<PrimitiveColumn<uint8_t>>(
|
||||
idx, length, null_count, bitmap, data, missing);
|
||||
case ColumnDType::kInt16:
|
||||
return std::make_shared<PrimitiveColumn<int16_t>>(
|
||||
idx, length, null_count, bitmap,
|
||||
reinterpret_cast<const int16_t*>(data), missing);
|
||||
case ColumnDType::kUInt16:
|
||||
return std::make_shared<PrimitiveColumn<uint16_t>>(
|
||||
idx, length, null_count, bitmap,
|
||||
reinterpret_cast<const uint16_t*>(data), missing);
|
||||
case ColumnDType::kInt32:
|
||||
return std::make_shared<PrimitiveColumn<int32_t>>(
|
||||
idx, length, null_count, bitmap,
|
||||
reinterpret_cast<const int32_t*>(data), missing);
|
||||
case ColumnDType::kUInt32:
|
||||
return std::make_shared<PrimitiveColumn<uint32_t>>(
|
||||
idx, length, null_count, bitmap,
|
||||
reinterpret_cast<const uint32_t*>(data), missing);
|
||||
case ColumnDType::kInt64:
|
||||
return std::make_shared<PrimitiveColumn<int64_t>>(
|
||||
idx, length, null_count, bitmap,
|
||||
reinterpret_cast<const int64_t*>(data), missing);
|
||||
case ColumnDType::kUInt64:
|
||||
return std::make_shared<PrimitiveColumn<uint64_t>>(
|
||||
idx, length, null_count, bitmap,
|
||||
reinterpret_cast<const uint64_t*>(data), missing);
|
||||
case ColumnDType::kFloat:
|
||||
return std::make_shared<PrimitiveColumn<float>>(
|
||||
idx, length, null_count, bitmap,
|
||||
reinterpret_cast<const float*>(data), missing);
|
||||
case ColumnDType::kDouble:
|
||||
return std::make_shared<PrimitiveColumn<double>>(
|
||||
idx, length, null_count, bitmap,
|
||||
reinterpret_cast<const double*>(data), missing);
|
||||
default:
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
struct ArrowArray* rb_;
|
||||
struct ArrowSchemaImporter* schema_;
|
||||
std::vector<std::shared_ptr<Column>> columns_;
|
||||
std::vector<size_t> row_offsets_;
|
||||
};
|
||||
|
||||
using ArrowColumnarBatchVec = std::vector<std::unique_ptr<ArrowColumnarBatch>>;
|
||||
class RecordBatchesIterAdapter: public dmlc::DataIter<ArrowColumnarBatchVec> {
|
||||
public:
|
||||
RecordBatchesIterAdapter(XGDMatrixCallbackNext* next_callback, int nbatch)
|
||||
: next_callback_{next_callback}, nbatches_{nbatch} {}
|
||||
|
||||
void BeforeFirst() override {
|
||||
CHECK(at_first_) << "Cannot reset RecordBatchesIterAdapter";
|
||||
}
|
||||
|
||||
bool Next() override {
|
||||
batches_.clear();
|
||||
while (batches_.size() < static_cast<size_t>(nbatches_) && (*next_callback_)(this) != 0) {
|
||||
at_first_ = false;
|
||||
}
|
||||
|
||||
if (batches_.size() > 0) {
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
void SetData(struct ArrowArray* rb, struct ArrowSchema* schema) {
|
||||
// Schema is only imported once at the beginning, regardless how many
|
||||
// baches are comming.
|
||||
// But even schema is not imported we still need to release its C data
|
||||
// exported from Arrow.
|
||||
if (at_first_ && schema) {
|
||||
schema_.Import(schema);
|
||||
} else {
|
||||
if (schema && schema->release) {
|
||||
schema->release(schema);
|
||||
}
|
||||
}
|
||||
if (rb) {
|
||||
batches_.push_back(std::make_unique<ArrowColumnarBatch>(rb, &schema_));
|
||||
}
|
||||
}
|
||||
|
||||
const ArrowColumnarBatchVec& Value() const override {
|
||||
return batches_;
|
||||
}
|
||||
|
||||
size_t NumColumns() const { return schema_.columns.size(); }
|
||||
size_t NumRows() const { return kAdapterUnknownSize; }
|
||||
|
||||
private:
|
||||
XGDMatrixCallbackNext *next_callback_;
|
||||
bool at_first_{true};
|
||||
int nbatches_;
|
||||
struct ArrowSchemaImporter schema_;
|
||||
ArrowColumnarBatchVec batches_;
|
||||
};
|
||||
|
||||
class SparsePageAdapterBatch {
|
||||
HostSparsePageView page_;
|
||||
|
||||
|
||||
@ -1,66 +0,0 @@
|
||||
/* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#define ARROW_FLAG_DICTIONARY_ORDERED 1
|
||||
#define ARROW_FLAG_NULLABLE 2
|
||||
#define ARROW_FLAG_MAP_KEYS_SORTED 4
|
||||
|
||||
struct ArrowSchema {
|
||||
// Array type description
|
||||
const char* format;
|
||||
const char* name;
|
||||
const char* metadata;
|
||||
int64_t flags;
|
||||
int64_t n_children;
|
||||
struct ArrowSchema** children;
|
||||
struct ArrowSchema* dictionary;
|
||||
|
||||
// Release callback
|
||||
void (*release)(struct ArrowSchema*);
|
||||
// Opaque producer-specific data
|
||||
void* private_data;
|
||||
};
|
||||
|
||||
struct ArrowArray {
|
||||
// Array data description
|
||||
int64_t length;
|
||||
int64_t null_count;
|
||||
int64_t offset;
|
||||
int64_t n_buffers;
|
||||
int64_t n_children;
|
||||
const void** buffers;
|
||||
struct ArrowArray** children;
|
||||
struct ArrowArray* dictionary;
|
||||
|
||||
// Release callback
|
||||
void (*release)(struct ArrowArray*);
|
||||
// Opaque producer-specific data
|
||||
void* private_data;
|
||||
};
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@ -1011,9 +1011,6 @@ template DMatrix* DMatrix::Create<data::CSCArrayAdapter>(data::CSCArrayAdapter*
|
||||
template DMatrix* DMatrix::Create(
|
||||
data::IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR>* adapter,
|
||||
float missing, int nthread, const std::string& cache_prefix, DataSplitMode data_split_mode);
|
||||
template DMatrix* DMatrix::Create<data::RecordBatchesIterAdapter>(
|
||||
data::RecordBatchesIterAdapter* adapter, float missing, int nthread, const std::string&,
|
||||
DataSplitMode data_split_mode);
|
||||
|
||||
SparsePage SparsePage::GetTranspose(int num_columns, int32_t n_threads) const {
|
||||
SparsePage transpose;
|
||||
|
||||
@ -361,78 +361,4 @@ template SimpleDMatrix::SimpleDMatrix(FileAdapter* adapter, float missing, int n
|
||||
template SimpleDMatrix::SimpleDMatrix(
|
||||
IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR>* adapter,
|
||||
float missing, int nthread, DataSplitMode data_split_mode);
|
||||
|
||||
template <>
|
||||
SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, float missing, int nthread,
|
||||
DataSplitMode data_split_mode) {
|
||||
Context ctx;
|
||||
ctx.nthread = nthread;
|
||||
|
||||
auto& offset_vec = sparse_page_->offset.HostVector();
|
||||
auto& data_vec = sparse_page_->data.HostVector();
|
||||
uint64_t total_batch_size = 0;
|
||||
uint64_t total_elements = 0;
|
||||
|
||||
adapter->BeforeFirst();
|
||||
// Iterate over batches of input data
|
||||
while (adapter->Next()) {
|
||||
auto& batches = adapter->Value();
|
||||
size_t num_elements = 0;
|
||||
size_t num_rows = 0;
|
||||
// Import Arrow RecordBatches
|
||||
#pragma omp parallel for reduction(+ : num_elements, num_rows) num_threads(ctx.Threads())
|
||||
for (int i = 0; i < static_cast<int>(batches.size()); ++i) { // NOLINT
|
||||
num_elements += batches[i]->Import(missing);
|
||||
num_rows += batches[i]->Size();
|
||||
}
|
||||
total_elements += num_elements;
|
||||
total_batch_size += num_rows;
|
||||
// Compute global offset for every row and starting row for every batch
|
||||
std::vector<uint64_t> batch_offsets(batches.size());
|
||||
for (size_t i = 0; i < batches.size(); ++i) {
|
||||
if (i == 0) {
|
||||
batch_offsets[i] = total_batch_size - num_rows;
|
||||
batches[i]->ShiftRowOffsets(total_elements - num_elements);
|
||||
} else {
|
||||
batch_offsets[i] = batch_offsets[i - 1] + batches[i - 1]->Size();
|
||||
batches[i]->ShiftRowOffsets(batches[i - 1]->RowOffsets().back());
|
||||
}
|
||||
}
|
||||
// Pre-allocate DMatrix memory
|
||||
data_vec.resize(total_elements);
|
||||
offset_vec.resize(total_batch_size + 1);
|
||||
// Copy data into DMatrix
|
||||
#pragma omp parallel num_threads(ctx.Threads())
|
||||
{
|
||||
#pragma omp for nowait
|
||||
for (int i = 0; i < static_cast<int>(batches.size()); ++i) { // NOLINT
|
||||
size_t begin = batches[i]->RowOffsets()[0];
|
||||
for (size_t k = 0; k < batches[i]->Size(); ++k) {
|
||||
for (size_t j = 0; j < batches[i]->NumColumns(); ++j) {
|
||||
auto element = batches[i]->GetColumn(j).GetElement(k);
|
||||
if (!std::isnan(element.value)) {
|
||||
data_vec[begin++] = Entry(element.column_idx, element.value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#pragma omp for nowait
|
||||
for (int i = 0; i < static_cast<int>(batches.size()); ++i) {
|
||||
auto& offsets = batches[i]->RowOffsets();
|
||||
std::copy(offsets.begin() + 1, offsets.end(), offset_vec.begin() + batch_offsets[i] + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Synchronise worker columns
|
||||
info_.num_col_ = adapter->NumColumns();
|
||||
info_.data_split_mode = data_split_mode;
|
||||
ReindexFeatures(&ctx);
|
||||
info_.SynchronizeNumberOfColumns();
|
||||
|
||||
info_.num_row_ = total_batch_size;
|
||||
info_.num_nonzero_ = data_vec.size();
|
||||
CHECK_EQ(offset_vec.back(), info_.num_nonzero_);
|
||||
|
||||
fmat_ctx_ = ctx;
|
||||
}
|
||||
} // namespace xgboost::data
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2015-2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2015-2023, XGBoost Contributors
|
||||
* \file simple_dmatrix.h
|
||||
* \brief In-memory version of DMatrix.
|
||||
* \author Tianqi Chen
|
||||
@ -15,8 +15,7 @@
|
||||
|
||||
#include "gradient_index.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
namespace xgboost::data {
|
||||
// Used for single batch data.
|
||||
class SimpleDMatrix : public DMatrix {
|
||||
public:
|
||||
@ -75,6 +74,5 @@ class SimpleDMatrix : public DMatrix {
|
||||
// Context used only for DMatrix initialization.
|
||||
Context fmat_ctx_;
|
||||
};
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::data
|
||||
#endif // XGBOOST_DATA_SIMPLE_DMATRIX_H_
|
||||
|
||||
@ -22,7 +22,7 @@ pytestmark = pytest.mark.skipif(
|
||||
dpath = "demo/data/"
|
||||
|
||||
|
||||
class TestArrowTable(unittest.TestCase):
|
||||
class TestArrowTable:
|
||||
def test_arrow_table(self):
|
||||
df = pd.DataFrame(
|
||||
[[0, 1, 2.0, 3.0], [1, 2, 3.0, 4.0]], columns=["a", "b", "c", "d"]
|
||||
@ -52,7 +52,8 @@ class TestArrowTable(unittest.TestCase):
|
||||
assert dm.num_row() == 4
|
||||
assert dm.num_col() == 3
|
||||
|
||||
def test_arrow_train(self):
|
||||
@pytest.mark.parametrize("DMatrixT", [xgb.DMatrix, xgb.QuantileDMatrix])
|
||||
def test_arrow_train(self, DMatrixT):
|
||||
import pandas as pd
|
||||
|
||||
rows = 100
|
||||
@ -64,16 +65,24 @@ class TestArrowTable(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
y = pd.Series(np.random.randn(rows))
|
||||
|
||||
table = pa.Table.from_pandas(X)
|
||||
dtrain1 = xgb.DMatrix(table)
|
||||
dtrain1.set_label(y)
|
||||
dtrain1 = DMatrixT(table)
|
||||
dtrain1.set_label(pa.Table.from_pandas(pd.DataFrame(y)))
|
||||
bst1 = xgb.train({}, dtrain1, num_boost_round=10)
|
||||
preds1 = bst1.predict(xgb.DMatrix(X))
|
||||
dtrain2 = xgb.DMatrix(X, y)
|
||||
preds1 = bst1.predict(DMatrixT(X))
|
||||
|
||||
dtrain2 = DMatrixT(X, y)
|
||||
bst2 = xgb.train({}, dtrain2, num_boost_round=10)
|
||||
preds2 = bst2.predict(xgb.DMatrix(X))
|
||||
preds2 = bst2.predict(DMatrixT(X))
|
||||
|
||||
np.testing.assert_allclose(preds1, preds2)
|
||||
|
||||
preds3 = bst2.inplace_predict(table)
|
||||
np.testing.assert_allclose(preds1, preds3)
|
||||
assert bst2.feature_names == ["A", "B", "C"]
|
||||
assert bst2.feature_types == ["int", "float", "int"]
|
||||
|
||||
def test_arrow_survival(self):
|
||||
data = os.path.join(tm.data_dir(__file__), "veterans_lung_cancer.csv")
|
||||
table = pc.read_csv(data)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user