Support building SimpleDMatrix from Arrow data format (#7512)
* Integrate with Arrow C data API.
* Support Arrow dataset.
* Support Arrow table.

Co-authored-by: Xiaochang Wu <xiaochang.wu@intel.com>
Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
Co-authored-by: Zhang Zhang <zhang.zhang@intel.com>
This commit is contained in:
@@ -2,10 +2,11 @@
|
||||
# pylint: disable=too-many-return-statements, import-error
|
||||
'''Data dispatching for DMatrix.'''
|
||||
import ctypes
|
||||
from distutils import version
|
||||
import json
|
||||
import warnings
|
||||
import os
|
||||
from typing import Any, Tuple, Callable, Optional, List, Union
|
||||
from typing import Any, Tuple, Callable, Optional, List, Union, Iterator
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -466,6 +467,92 @@ def _from_dt_df(
|
||||
return handle, feature_names, feature_types
|
||||
|
||||
|
||||
def _is_arrow(data) -> bool:
|
||||
try:
|
||||
import pyarrow as pa
|
||||
from pyarrow import dataset as arrow_dataset
|
||||
return isinstance(data, (pa.Table, arrow_dataset.Dataset))
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def record_batch_data_iter(data_iter: Iterator) -> Callable:
    """Build a callback that feeds Arrow record batches to the native library.

    We are not using class DataIter because it is only intended for building
    Device DMatrix and external memory DMatrix.

    Parameters
    ----------
    data_iter :
        Iterator yielding record batches; each one is exported through the
        Arrow C data interface.

    Returns
    -------
    A function taking the native data handle (as an integer) that imports the
    next batch and returns 1, or returns 0 once ``data_iter`` is exhausted.

    """
    from pyarrow.cffi import ffi

    # Every allocated struct is retained in these lists for as long as the
    # callback exists -- presumably so the exported memory is not released
    # while the native side may still be reading it.
    schema_structs: List[ffi.CData] = []
    array_structs: List[ffi.CData] = []

    def _next(data_handle: int) -> int:
        try:
            record_batch = next(data_iter)

            schema_struct = ffi.new("struct ArrowSchema*")
            schema_structs.append(schema_struct)
            array_struct = ffi.new("struct ArrowArray*")
            array_structs.append(array_struct)

            # The C API is handed raw addresses, not cffi objects.
            schema_addr = int(ffi.cast("uintptr_t", schema_struct))
            array_addr = int(ffi.cast("uintptr_t", array_struct))

            # pylint: disable=protected-access
            record_batch._export_to_c(array_addr, schema_addr)

            status = _LIB.XGImportArrowRecordBatch(
                ctypes.c_void_p(data_handle),
                ctypes.c_void_p(array_addr),
                ctypes.c_void_p(schema_addr),
            )
            _check_call(status)
        except StopIteration:
            return 0
        return 1

    return _next
|
||||
|
||||
|
||||
def _from_arrow(
    data,
    missing: float,
    nthread: int,
    feature_names: Optional[List[str]],
    feature_types: Optional[List[str]],
    enable_categorical: bool,
) -> Tuple[ctypes.c_void_p, Optional[List[str]], Optional[List[str]]]:
    """Create a DMatrix handle from Arrow data via the record-batch callback.

    Parameters
    ----------
    data :
        Arrow object exposing ``schema.types`` and ``to_batches()``.
    missing :
        Forwarded to the native library as the ``missing`` config entry.
    nthread :
        Forwarded to the native library as the ``nthread`` config entry.
    feature_names, feature_types :
        Returned unchanged.
    enable_categorical :
        Must be false; categorical Arrow data is rejected.

    Returns
    -------
    Tuple of the new native handle, ``feature_names`` and ``feature_types``.

    Raises
    ------
    ValueError
        If any column is neither integer nor floating point, or if
        ``enable_categorical`` is set.
    """
    import pyarrow as pa

    if not all(
        pa.types.is_integer(t) or pa.types.is_floating(t) for t in data.schema.types
    ):
        raise ValueError(
            "Features in dataset can only be integers or floating point number"
        )
    if enable_categorical:
        raise ValueError("categorical data in arrow is not supported yet.")

    # Only the major component is needed.  Read it straight from the version
    # string: distutils' StrictVersion rejects dev/post/rc suffixes such as
    # "7.0.0.dev123", and distutils itself is deprecated (PEP 632).
    major = int(pa.__version__.split(".", 1)[0])
    if major == 4:
        rb_iter = iter(data.to_batches())
    else:
        # use_async=True to workaround pyarrow 6.0.1 hang,
        # see Modin-3982 and ARROW-15362
        rb_iter = iter(data.to_batches(use_async=True))
    it = record_batch_data_iter(rb_iter)
    # Bound to a local so the ctypes callback object stays referenced for the
    # duration of the native call below.
    next_callback = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p)(it)
    handle = ctypes.c_void_p()

    config = bytes(json.dumps({"missing": missing, "nthread": nthread}), "utf-8")
    _check_call(
        _LIB.XGDMatrixCreateFromArrowCallback(
            next_callback,
            config,
            ctypes.byref(handle),
        )
    )
    return handle, feature_names, feature_types
|
||||
|
||||
|
||||
def _is_cudf_df(data) -> bool:
    """Check, via ``lazy_isinstance``, whether ``data`` is a cuDF ``DataFrame``."""
    module_name = "cudf.core.dataframe"
    class_name = "DataFrame"
    return lazy_isinstance(data, module_name, class_name)
|
||||
|
||||
@@ -814,6 +901,9 @@ def dispatch_data_backend(
|
||||
return _from_pandas_series(
|
||||
data, missing, threads, enable_categorical, feature_names, feature_types
|
||||
)
|
||||
if _is_arrow(data):
|
||||
return _from_arrow(
|
||||
data, missing, threads, feature_names, feature_types, enable_categorical)
|
||||
if _has_array_protocol(data):
|
||||
array = np.asarray(data)
|
||||
return _from_numpy_array(array, missing, threads, feature_names, feature_types)
|
||||
@@ -954,6 +1044,7 @@ def dispatch_meta_backend(
|
||||
_meta_from_numpy(data, name, dtype, handle)
|
||||
return
|
||||
if _has_array_protocol(data):
|
||||
# pyarrow goes here.
|
||||
array = np.asarray(data)
|
||||
_meta_from_numpy(array, name, dtype, handle)
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user