Support arrow through pandas ext types. (#9612)
- Use pandas extension type for pyarrow support. - Additional support for QDM. - Additional support for inplace_predict.
This commit is contained in:
@@ -2431,6 +2431,8 @@ class Booster:
|
||||
|
||||
from .data import (
|
||||
_array_interface,
|
||||
_arrow_transform,
|
||||
_is_arrow,
|
||||
_is_cudf_df,
|
||||
_is_cupy_array,
|
||||
_is_list,
|
||||
@@ -2442,6 +2444,8 @@ class Booster:
|
||||
)
|
||||
|
||||
enable_categorical = True
|
||||
if _is_arrow(data):
|
||||
data = _arrow_transform(data)
|
||||
if _is_pandas_series(data):
|
||||
import pandas as pd
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import ctypes
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from typing import Any, Callable, Iterator, List, Optional, Sequence, Tuple, cast
|
||||
from typing import Any, Callable, List, Optional, Sequence, Tuple, cast
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -676,86 +676,51 @@ def _from_dt_df(
|
||||
|
||||
|
||||
def _is_arrow(data: DataType) -> bool:
|
||||
try:
|
||||
import pyarrow as pa
|
||||
from pyarrow import dataset as arrow_dataset
|
||||
|
||||
return isinstance(data, (pa.Table, arrow_dataset.Dataset))
|
||||
except ImportError:
|
||||
return False
|
||||
return lazy_isinstance(data, "pyarrow.lib", "Table") or lazy_isinstance(
|
||||
data, "pyarrow._dataset", "Dataset"
|
||||
)
|
||||
|
||||
|
||||
def record_batch_data_iter(data_iter: Iterator) -> Callable:
|
||||
"""Data iterator used to ingest Arrow columnar record batches. We are not using
|
||||
class DataIter because it is only intended for building Device DMatrix and external
|
||||
memory DMatrix.
|
||||
|
||||
"""
|
||||
from pyarrow.cffi import ffi
|
||||
|
||||
c_schemas: List[ffi.CData] = []
|
||||
c_arrays: List[ffi.CData] = []
|
||||
|
||||
def _next(data_handle: int) -> int:
|
||||
from pyarrow.cffi import ffi
|
||||
|
||||
try:
|
||||
batch = next(data_iter)
|
||||
c_schemas.append(ffi.new("struct ArrowSchema*"))
|
||||
c_arrays.append(ffi.new("struct ArrowArray*"))
|
||||
ptr_schema = int(ffi.cast("uintptr_t", c_schemas[-1]))
|
||||
ptr_array = int(ffi.cast("uintptr_t", c_arrays[-1]))
|
||||
# pylint: disable=protected-access
|
||||
batch._export_to_c(ptr_array, ptr_schema)
|
||||
_check_call(
|
||||
_LIB.XGImportArrowRecordBatch(
|
||||
ctypes.c_void_p(data_handle),
|
||||
ctypes.c_void_p(ptr_array),
|
||||
ctypes.c_void_p(ptr_schema),
|
||||
)
|
||||
)
|
||||
return 1
|
||||
except StopIteration:
|
||||
return 0
|
||||
|
||||
return _next
|
||||
|
||||
|
||||
def _from_arrow(
|
||||
data: DataType,
|
||||
missing: FloatCompatible,
|
||||
nthread: int,
|
||||
feature_names: Optional[FeatureNames],
|
||||
feature_types: Optional[FeatureTypes],
|
||||
enable_categorical: bool,
|
||||
) -> DispatchedDataBackendReturnType:
|
||||
def _arrow_transform(data: DataType) -> Any:
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
from pyarrow.dataset import Dataset
|
||||
|
||||
if not all(
|
||||
pa.types.is_integer(t) or pa.types.is_floating(t) for t in data.schema.types
|
||||
):
|
||||
raise ValueError(
|
||||
"Features in dataset can only be integers or floating point number"
|
||||
)
|
||||
if enable_categorical:
|
||||
raise ValueError("categorical data in arrow is not supported yet.")
|
||||
if isinstance(data, Dataset):
|
||||
raise TypeError("arrow Dataset is not supported.")
|
||||
|
||||
batches = data.to_batches()
|
||||
rb_iter = iter(batches)
|
||||
it = record_batch_data_iter(rb_iter)
|
||||
next_callback = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p)(it)
|
||||
handle = ctypes.c_void_p()
|
||||
config = from_pystr_to_cstr(
|
||||
json.dumps({"missing": missing, "nthread": nthread, "nbatch": len(batches)})
|
||||
)
|
||||
_check_call(
|
||||
_LIB.XGDMatrixCreateFromArrowCallback(
|
||||
next_callback,
|
||||
config,
|
||||
ctypes.byref(handle),
|
||||
)
|
||||
)
|
||||
return handle, feature_names, feature_types
|
||||
data = cast(pa.Table, data)
|
||||
|
||||
def type_mapper(dtype: pa.DataType) -> Optional[str]:
|
||||
"""Maps pyarrow type to pandas arrow extension type."""
|
||||
if pa.types.is_int8(dtype):
|
||||
return pd.ArrowDtype(pa.int8())
|
||||
if pa.types.is_int16(dtype):
|
||||
return pd.ArrowDtype(pa.int16())
|
||||
if pa.types.is_int32(dtype):
|
||||
return pd.ArrowDtype(pa.int32())
|
||||
if pa.types.is_int64(dtype):
|
||||
return pd.ArrowDtype(pa.int64())
|
||||
if pa.types.is_uint8(dtype):
|
||||
return pd.ArrowDtype(pa.uint8())
|
||||
if pa.types.is_uint16(dtype):
|
||||
return pd.ArrowDtype(pa.uint16())
|
||||
if pa.types.is_uint32(dtype):
|
||||
return pd.ArrowDtype(pa.uint32())
|
||||
if pa.types.is_uint64(dtype):
|
||||
return pd.ArrowDtype(pa.uint64())
|
||||
if pa.types.is_float16(dtype):
|
||||
return pd.ArrowDtype(pa.float16())
|
||||
if pa.types.is_float32(dtype):
|
||||
return pd.ArrowDtype(pa.float32())
|
||||
if pa.types.is_float64(dtype):
|
||||
return pd.ArrowDtype(pa.float64())
|
||||
if pa.types.is_boolean(dtype):
|
||||
return pd.ArrowDtype(pa.bool_())
|
||||
return None
|
||||
|
||||
df = data.to_pandas(types_mapper=type_mapper)
|
||||
return df
|
||||
|
||||
|
||||
def _is_cudf_df(data: DataType) -> bool:
|
||||
@@ -1081,6 +1046,8 @@ def dispatch_data_backend(
|
||||
return _from_list(data, missing, threads, feature_names, feature_types)
|
||||
if _is_tuple(data):
|
||||
return _from_tuple(data, missing, threads, feature_names, feature_types)
|
||||
if _is_arrow(data):
|
||||
data = _arrow_transform(data)
|
||||
if _is_pandas_series(data):
|
||||
import pandas as pd
|
||||
|
||||
@@ -1114,10 +1081,6 @@ def dispatch_data_backend(
|
||||
return _from_pandas_series(
|
||||
data, missing, threads, enable_categorical, feature_names, feature_types
|
||||
)
|
||||
if _is_arrow(data):
|
||||
return _from_arrow(
|
||||
data, missing, threads, feature_names, feature_types, enable_categorical
|
||||
)
|
||||
if _has_array_protocol(data):
|
||||
array = np.asarray(data)
|
||||
return _from_numpy_array(array, missing, threads, feature_names, feature_types)
|
||||
@@ -1217,6 +1180,8 @@ def dispatch_meta_backend(
|
||||
if _is_np_array_like(data):
|
||||
_meta_from_numpy(data, name, dtype, handle)
|
||||
return
|
||||
if _is_arrow(data):
|
||||
data = _arrow_transform(data)
|
||||
if _is_pandas_df(data):
|
||||
data, _, _ = _transform_pandas_df(data, False, meta=name, meta_type=dtype)
|
||||
_meta_from_numpy(data, name, dtype, handle)
|
||||
@@ -1311,6 +1276,8 @@ def _proxy_transform(
|
||||
import pandas as pd
|
||||
|
||||
data = pd.DataFrame(data)
|
||||
if _is_arrow(data):
|
||||
data = _arrow_transform(data)
|
||||
if _is_pandas_df(data):
|
||||
arr, feature_names, feature_types = _transform_pandas_df(
|
||||
data, enable_categorical, feature_names, feature_types
|
||||
|
||||
Reference in New Issue
Block a user