Support primitive types of pyarrow-backed pandas dataframe. (#8653)
Categorical data (dictionary) is not supported at the moment.
This commit is contained in:
parent
3760cede0f
commit
1325ba9251
@ -251,7 +251,25 @@ pandas_nullable_mapper = {
|
|||||||
"boolean": "i",
|
"boolean": "i",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pandas_pyarrow_mapper = {
|
||||||
|
"int8[pyarrow]": "i",
|
||||||
|
"int16[pyarrow]": "i",
|
||||||
|
"int32[pyarrow]": "i",
|
||||||
|
"int64[pyarrow]": "i",
|
||||||
|
"uint8[pyarrow]": "i",
|
||||||
|
"uint16[pyarrow]": "i",
|
||||||
|
"uint32[pyarrow]": "i",
|
||||||
|
"uint64[pyarrow]": "i",
|
||||||
|
"float[pyarrow]": "float",
|
||||||
|
"float32[pyarrow]": "float",
|
||||||
|
"double[pyarrow]": "float",
|
||||||
|
"float64[pyarrow]": "float",
|
||||||
|
"bool[pyarrow]": "i",
|
||||||
|
}
|
||||||
|
|
||||||
_pandas_dtype_mapper.update(pandas_nullable_mapper)
|
_pandas_dtype_mapper.update(pandas_nullable_mapper)
|
||||||
|
_pandas_dtype_mapper.update(pandas_pyarrow_mapper)
|
||||||
|
|
||||||
|
|
||||||
_ENABLE_CAT_ERR = (
|
_ENABLE_CAT_ERR = (
|
||||||
"When categorical type is supplied, The experimental DMatrix parameter"
|
"When categorical type is supplied, The experimental DMatrix parameter"
|
||||||
@ -277,13 +295,14 @@ def _invalid_dataframe_dtype(data: DataType) -> None:
|
|||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
|
||||||
def _pandas_feature_info(
|
def pandas_feature_info(
|
||||||
data: DataFrame,
|
data: DataFrame,
|
||||||
meta: Optional[str],
|
meta: Optional[str],
|
||||||
feature_names: Optional[FeatureNames],
|
feature_names: Optional[FeatureNames],
|
||||||
feature_types: Optional[FeatureTypes],
|
feature_types: Optional[FeatureTypes],
|
||||||
enable_categorical: bool,
|
enable_categorical: bool,
|
||||||
) -> Tuple[Optional[FeatureNames], Optional[FeatureTypes]]:
|
) -> Tuple[Optional[FeatureNames], Optional[FeatureTypes]]:
|
||||||
|
"""Handle feature info for pandas dataframe."""
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from pandas.api.types import is_categorical_dtype, is_sparse
|
from pandas.api.types import is_categorical_dtype, is_sparse
|
||||||
|
|
||||||
@ -302,7 +321,9 @@ def _pandas_feature_info(
|
|||||||
for dtype in data.dtypes:
|
for dtype in data.dtypes:
|
||||||
if is_sparse(dtype):
|
if is_sparse(dtype):
|
||||||
feature_types.append(_pandas_dtype_mapper[dtype.subtype.name])
|
feature_types.append(_pandas_dtype_mapper[dtype.subtype.name])
|
||||||
elif is_categorical_dtype(dtype) and enable_categorical:
|
elif (
|
||||||
|
is_categorical_dtype(dtype) or is_pa_ext_categorical_dtype(dtype)
|
||||||
|
) and enable_categorical:
|
||||||
feature_types.append(CAT_T)
|
feature_types.append(CAT_T)
|
||||||
else:
|
else:
|
||||||
feature_types.append(_pandas_dtype_mapper[dtype.name])
|
feature_types.append(_pandas_dtype_mapper[dtype.name])
|
||||||
@ -310,7 +331,7 @@ def _pandas_feature_info(
|
|||||||
|
|
||||||
|
|
||||||
def is_nullable_dtype(dtype: PandasDType) -> bool:
|
def is_nullable_dtype(dtype: PandasDType) -> bool:
|
||||||
"""Wether dtype is a pandas nullable type."""
|
"""Whether dtype is a pandas nullable type."""
|
||||||
from pandas.api.types import (
|
from pandas.api.types import (
|
||||||
is_bool_dtype,
|
is_bool_dtype,
|
||||||
is_categorical_dtype,
|
is_categorical_dtype,
|
||||||
@ -319,38 +340,63 @@ def is_nullable_dtype(dtype: PandasDType) -> bool:
|
|||||||
)
|
)
|
||||||
|
|
||||||
is_int = is_integer_dtype(dtype) and dtype.name in pandas_nullable_mapper
|
is_int = is_integer_dtype(dtype) and dtype.name in pandas_nullable_mapper
|
||||||
# np.bool has alias `bool`, while pd.BooleanDtype has `bzoolean`.
|
# np.bool has alias `bool`, while pd.BooleanDtype has `boolean`.
|
||||||
is_bool = is_bool_dtype(dtype) and dtype.name == "boolean"
|
is_bool = is_bool_dtype(dtype) and dtype.name == "boolean"
|
||||||
is_float = is_float_dtype(dtype) and dtype.name in pandas_nullable_mapper
|
is_float = is_float_dtype(dtype) and dtype.name in pandas_nullable_mapper
|
||||||
return is_int or is_bool or is_float or is_categorical_dtype(dtype)
|
return is_int or is_bool or is_float or is_categorical_dtype(dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def is_pa_ext_dtype(dtype: Any) -> bool:
|
||||||
|
"""Return whether dtype is a pyarrow extension type for pandas"""
|
||||||
|
return hasattr(dtype, "pyarrow_dtype")
|
||||||
|
|
||||||
|
|
||||||
|
def is_pa_ext_categorical_dtype(dtype: Any) -> bool:
|
||||||
|
"""Check whether dtype is a dictionary type."""
|
||||||
|
return lazy_isinstance(
|
||||||
|
getattr(dtype, "pyarrow_dtype", None), "pyarrow.lib", "DictionaryType"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def pandas_cat_null(data: DataFrame) -> DataFrame:
|
def pandas_cat_null(data: DataFrame) -> DataFrame:
|
||||||
"""Handle categorical dtype and nullable extension types from pandas."""
|
"""Handle categorical dtype and nullable extension types from pandas."""
|
||||||
|
import pandas as pd
|
||||||
from pandas.api.types import is_categorical_dtype
|
from pandas.api.types import is_categorical_dtype
|
||||||
|
|
||||||
# handle category codes and nullable.
|
# handle category codes and nullable.
|
||||||
cat_columns = []
|
cat_columns = []
|
||||||
nul_columns = []
|
nul_columns = []
|
||||||
|
# avoid an unnecessary conversion if possible
|
||||||
for col, dtype in zip(data.columns, data.dtypes):
|
for col, dtype in zip(data.columns, data.dtypes):
|
||||||
if is_categorical_dtype(dtype):
|
if is_categorical_dtype(dtype):
|
||||||
cat_columns.append(col)
|
cat_columns.append(col)
|
||||||
# avoid an unnecessary conversion if possible
|
elif is_pa_ext_categorical_dtype(dtype):
|
||||||
|
raise ValueError(
|
||||||
|
"pyarrow dictionary type is not supported. Use pandas category instead."
|
||||||
|
)
|
||||||
elif is_nullable_dtype(dtype):
|
elif is_nullable_dtype(dtype):
|
||||||
nul_columns.append(col)
|
nul_columns.append(col)
|
||||||
|
|
||||||
if cat_columns or nul_columns:
|
if cat_columns or nul_columns:
|
||||||
# Avoid transformation due to: PerformanceWarning: DataFrame is highly
|
# Avoid transformation due to: PerformanceWarning: DataFrame is highly
|
||||||
# fragmented
|
# fragmented
|
||||||
transformed = data.copy()
|
transformed = data.copy(deep=False)
|
||||||
else:
|
else:
|
||||||
transformed = data
|
transformed = data
|
||||||
|
|
||||||
|
def cat_codes(ser: pd.Series) -> pd.Series:
|
||||||
|
if is_categorical_dtype(ser.dtype):
|
||||||
|
return ser.cat.codes
|
||||||
|
assert is_pa_ext_categorical_dtype(ser.dtype)
|
||||||
|
# Not yet supported, the index is not ordered for some reason. Alternately:
|
||||||
|
# `combine_chunks().to_pandas().cat.codes`. The result is the same.
|
||||||
|
return ser.array.__arrow_array__().combine_chunks().dictionary_encode().indices
|
||||||
|
|
||||||
if cat_columns:
|
if cat_columns:
|
||||||
# DF doesn't have the cat attribute, as a result, we use apply here
|
# DF doesn't have the cat attribute, as a result, we use apply here
|
||||||
transformed[cat_columns] = (
|
transformed[cat_columns] = (
|
||||||
transformed[cat_columns]
|
transformed[cat_columns]
|
||||||
.apply(lambda x: x.cat.codes)
|
.apply(cat_codes)
|
||||||
.astype(np.float32)
|
.astype(np.float32)
|
||||||
.replace(-1.0, np.NaN)
|
.replace(-1.0, np.NaN)
|
||||||
)
|
)
|
||||||
@ -364,6 +410,29 @@ def pandas_cat_null(data: DataFrame) -> DataFrame:
|
|||||||
return transformed
|
return transformed
|
||||||
|
|
||||||
|
|
||||||
|
def pandas_ext_num_types(data: DataFrame) -> DataFrame:
|
||||||
|
"""Experimental suppport for handling pyarrow extension numeric types."""
|
||||||
|
import pandas as pd
|
||||||
|
import pyarrow as pa
|
||||||
|
|
||||||
|
for col, dtype in zip(data.columns, data.dtypes):
|
||||||
|
if not is_pa_ext_dtype(dtype):
|
||||||
|
continue
|
||||||
|
# No copy, callstack:
|
||||||
|
# pandas.core.internals.managers.SingleBlockManager.array_values()
|
||||||
|
# pandas.core.internals.blocks.EABackedBlock.values
|
||||||
|
d_array: pd.arrays.ArrowExtensionArray = data[col].array
|
||||||
|
# no copy in __arrow_array__
|
||||||
|
# ArrowExtensionArray._data is a chunked array
|
||||||
|
aa: pa.ChunkedArray = d_array.__arrow_array__()
|
||||||
|
chunk: pa.Array = aa.combine_chunks()
|
||||||
|
# Alternately, we can use chunk.buffers(), which returns a list of buffers and
|
||||||
|
# we need to concatenate them ourselves.
|
||||||
|
arr = chunk.__array__()
|
||||||
|
data[col] = arr
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
def _transform_pandas_df(
|
def _transform_pandas_df(
|
||||||
data: DataFrame,
|
data: DataFrame,
|
||||||
enable_categorical: bool,
|
enable_categorical: bool,
|
||||||
@ -374,19 +443,27 @@ def _transform_pandas_df(
|
|||||||
) -> Tuple[np.ndarray, Optional[FeatureNames], Optional[FeatureTypes]]:
|
) -> Tuple[np.ndarray, Optional[FeatureNames], Optional[FeatureTypes]]:
|
||||||
from pandas.api.types import is_categorical_dtype, is_sparse
|
from pandas.api.types import is_categorical_dtype, is_sparse
|
||||||
|
|
||||||
if not all(
|
pyarrow_extension = False
|
||||||
|
for dtype in data.dtypes:
|
||||||
|
if not (
|
||||||
(dtype.name in _pandas_dtype_mapper)
|
(dtype.name in _pandas_dtype_mapper)
|
||||||
or is_sparse(dtype)
|
or is_sparse(dtype)
|
||||||
or (is_categorical_dtype(dtype) and enable_categorical)
|
or (is_categorical_dtype(dtype) and enable_categorical)
|
||||||
for dtype in data.dtypes
|
or is_pa_ext_dtype(dtype)
|
||||||
):
|
):
|
||||||
_invalid_dataframe_dtype(data)
|
_invalid_dataframe_dtype(data)
|
||||||
|
if is_pa_ext_dtype(dtype):
|
||||||
|
pyarrow_extension = True
|
||||||
|
|
||||||
feature_names, feature_types = _pandas_feature_info(
|
feature_names, feature_types = pandas_feature_info(
|
||||||
data, meta, feature_names, feature_types, enable_categorical
|
data, meta, feature_names, feature_types, enable_categorical
|
||||||
)
|
)
|
||||||
|
|
||||||
transformed = pandas_cat_null(data)
|
transformed = pandas_cat_null(data)
|
||||||
|
if pyarrow_extension:
|
||||||
|
if transformed is data:
|
||||||
|
transformed = data.copy(deep=False)
|
||||||
|
transformed = pandas_ext_num_types(transformed)
|
||||||
|
|
||||||
if meta and len(data.columns) > 1 and meta not in _matrix_meta:
|
if meta and len(data.columns) > 1 and meta not in _matrix_meta:
|
||||||
raise ValueError(f"DataFrame for {meta} cannot have multiple columns")
|
raise ValueError(f"DataFrame for {meta} cannot have multiple columns")
|
||||||
@ -1192,7 +1269,10 @@ def _proxy_transform(
|
|||||||
enable_categorical: bool,
|
enable_categorical: bool,
|
||||||
) -> Tuple[
|
) -> Tuple[
|
||||||
Union[bool, ctypes.c_void_p, np.ndarray],
|
Union[bool, ctypes.c_void_p, np.ndarray],
|
||||||
Optional[list], Optional[FeatureNames], Optional[FeatureTypes]]:
|
Optional[list],
|
||||||
|
Optional[FeatureNames],
|
||||||
|
Optional[FeatureTypes],
|
||||||
|
]:
|
||||||
if _is_cudf_df(data) or _is_cudf_ser(data):
|
if _is_cudf_df(data) or _is_cudf_ser(data):
|
||||||
return _transform_cudf_df(
|
return _transform_cudf_df(
|
||||||
data, feature_names, feature_types, enable_categorical
|
data, feature_names, feature_types, enable_categorical
|
||||||
@ -1212,6 +1292,7 @@ def _proxy_transform(
|
|||||||
return data, None, feature_names, feature_types
|
return data, None, feature_names, feature_types
|
||||||
if _is_pandas_series(data):
|
if _is_pandas_series(data):
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
data = pd.DataFrame(data)
|
data = pd.DataFrame(data)
|
||||||
if _is_pandas_df(data):
|
if _is_pandas_df(data):
|
||||||
arr, feature_names, feature_types = _transform_pandas_df(
|
arr, feature_names, feature_types = _transform_pandas_df(
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
from typing import Any, Generator, Tuple, Union
|
from typing import Any, Generator, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from xgboost.data import pandas_pyarrow_mapper
|
||||||
|
|
||||||
|
|
||||||
def np_dtypes(
|
def np_dtypes(
|
||||||
@ -124,3 +125,56 @@ def pd_dtypes() -> Generator:
|
|||||||
orig = pd.DataFrame(data, dtype=np.bool_ if Null is None else pd.BooleanDtype())
|
orig = pd.DataFrame(data, dtype=np.bool_ if Null is None else pd.BooleanDtype())
|
||||||
df = pd.DataFrame(data, dtype=pd.BooleanDtype())
|
df = pd.DataFrame(data, dtype=pd.BooleanDtype())
|
||||||
yield orig, df
|
yield orig, df
|
||||||
|
|
||||||
|
|
||||||
|
def pd_arrow_dtypes() -> Generator:
|
||||||
|
"""Pandas DataFrame with pyarrow backed type."""
|
||||||
|
import pandas as pd
|
||||||
|
import pyarrow as pa # pylint: disable=import-error
|
||||||
|
|
||||||
|
# Integer
|
||||||
|
dtypes = pandas_pyarrow_mapper
|
||||||
|
Null: Union[float, None, Any] = np.nan
|
||||||
|
orig = pd.DataFrame(
|
||||||
|
{"f0": [1, 2, Null, 3], "f1": [4, 3, Null, 1]}, dtype=np.float32
|
||||||
|
)
|
||||||
|
# Create a dictionary-backed dataframe, enable this when the roundtrip is
|
||||||
|
# implemented in pandas/pyarrow
|
||||||
|
#
|
||||||
|
# category = pd.ArrowDtype(pa.dictionary(pa.int32(), pa.int32(), ordered=True))
|
||||||
|
# df = pd.DataFrame({"f0": [0, 2, Null, 3], "f1": [4, 3, Null, 1]}, dtype=category)
|
||||||
|
|
||||||
|
# Error:
|
||||||
|
# >>> df.astype("category")
|
||||||
|
# Function 'dictionary_encode' has no kernel matching input types
|
||||||
|
# (array[dictionary<values=int32, indices=int32, ordered=0>])
|
||||||
|
|
||||||
|
# Error:
|
||||||
|
# pd_cat_df = pd.DataFrame(
|
||||||
|
# {"f0": [0, 2, Null, 3], "f1": [4, 3, Null, 1]},
|
||||||
|
# dtype="category"
|
||||||
|
# )
|
||||||
|
# pa_catcodes = (
|
||||||
|
# df["f1"].array.__arrow_array__().combine_chunks().to_pandas().cat.codes
|
||||||
|
# )
|
||||||
|
# pd_catcodes = pd_cat_df["f1"].cat.codes
|
||||||
|
# assert pd_catcodes.equals(pa_catcodes)
|
||||||
|
|
||||||
|
for Null in (None, pd.NA):
|
||||||
|
for dtype in dtypes:
|
||||||
|
if dtype.startswith("float16") or dtype.startswith("bool"):
|
||||||
|
continue
|
||||||
|
df = pd.DataFrame(
|
||||||
|
{"f0": [1, 2, Null, 3], "f1": [4, 3, Null, 1]}, dtype=dtype
|
||||||
|
)
|
||||||
|
yield orig, df
|
||||||
|
|
||||||
|
orig = pd.DataFrame(
|
||||||
|
{"f0": [True, False, pd.NA, True], "f1": [False, True, pd.NA, True]},
|
||||||
|
dtype=pd.BooleanDtype(),
|
||||||
|
)
|
||||||
|
df = pd.DataFrame(
|
||||||
|
{"f0": [True, False, pd.NA, True], "f1": [False, True, pd.NA, True]},
|
||||||
|
dtype=pd.ArrowDtype(pa.bool_()),
|
||||||
|
)
|
||||||
|
yield orig, df
|
||||||
|
|||||||
@ -1,7 +1,9 @@
|
|||||||
|
from typing import Type
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from test_dmatrix import set_base_margin_info
|
from test_dmatrix import set_base_margin_info
|
||||||
from xgboost.testing.data import pd_dtypes
|
from xgboost.testing.data import pd_arrow_dtypes, pd_dtypes
|
||||||
|
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
from xgboost import testing as tm
|
from xgboost import testing as tm
|
||||||
@ -305,9 +307,17 @@ class TestPandas:
|
|||||||
# series
|
# series
|
||||||
enable_categorical = is_categorical(df.dtype)
|
enable_categorical = is_categorical(df.dtype)
|
||||||
|
|
||||||
m_orig = DMatrixT(orig, enable_categorical=enable_categorical)
|
f0_orig = orig[orig.columns[0]] if isinstance(orig, pd.DataFrame) else orig
|
||||||
|
f0 = df[df.columns[0]] if isinstance(df, pd.DataFrame) else df
|
||||||
|
y_orig = f0_orig.astype(pd.Float32Dtype()).fillna(0)
|
||||||
|
y = f0.astype(pd.Float32Dtype()).fillna(0)
|
||||||
|
|
||||||
|
m_orig = DMatrixT(orig, enable_categorical=enable_categorical, label=y_orig)
|
||||||
# extension types
|
# extension types
|
||||||
m_etype = DMatrixT(df, enable_categorical=enable_categorical)
|
copy = df.copy()
|
||||||
|
m_etype = DMatrixT(df, enable_categorical=enable_categorical, label=y)
|
||||||
|
# no mutation
|
||||||
|
assert df.equals(copy)
|
||||||
# different from pd.BooleanDtype(), None is converted to False with bool
|
# different from pd.BooleanDtype(), None is converted to False with bool
|
||||||
if hasattr(orig.dtypes, "__iter__") and any(
|
if hasattr(orig.dtypes, "__iter__") and any(
|
||||||
dtype == "bool" for dtype in orig.dtypes
|
dtype == "bool" for dtype in orig.dtypes
|
||||||
@ -316,7 +326,32 @@ class TestPandas:
|
|||||||
else:
|
else:
|
||||||
assert tm.predictor_equal(m_orig, m_etype)
|
assert tm.predictor_equal(m_orig, m_etype)
|
||||||
|
|
||||||
|
np.testing.assert_allclose(m_orig.get_label(), m_etype.get_label())
|
||||||
|
np.testing.assert_allclose(m_etype.get_label(), y.values.astype(np.float32))
|
||||||
|
|
||||||
if isinstance(df, pd.DataFrame):
|
if isinstance(df, pd.DataFrame):
|
||||||
f0 = df["f0"]
|
f0 = df["f0"]
|
||||||
with pytest.raises(ValueError, match="Label contains NaN"):
|
with pytest.raises(ValueError, match="Label contains NaN"):
|
||||||
xgb.DMatrix(df, f0, enable_categorical=enable_categorical)
|
xgb.DMatrix(df, f0, enable_categorical=enable_categorical)
|
||||||
|
|
||||||
|
@pytest.mark.skipif(**tm.no_arrow())
|
||||||
|
@pytest.mark.parametrize("DMatrixT", [xgb.DMatrix, xgb.QuantileDMatrix])
|
||||||
|
def test_pyarrow_type(self, DMatrixT: Type[xgb.DMatrix]) -> None:
|
||||||
|
for orig, df in pd_arrow_dtypes():
|
||||||
|
f0_orig: pd.Series = orig["f0"]
|
||||||
|
f0 = df["f0"]
|
||||||
|
|
||||||
|
if f0.dtype.name.startswith("bool"):
|
||||||
|
y = None
|
||||||
|
y_orig = None
|
||||||
|
else:
|
||||||
|
y_orig = f0_orig.fillna(0, inplace=False)
|
||||||
|
y = f0.fillna(0, inplace=False)
|
||||||
|
|
||||||
|
m_orig = DMatrixT(orig, enable_categorical=True, label=y_orig)
|
||||||
|
m_etype = DMatrixT(df, enable_categorical=True, label=y)
|
||||||
|
|
||||||
|
assert tm.predictor_equal(m_orig, m_etype)
|
||||||
|
if y is not None:
|
||||||
|
np.testing.assert_allclose(m_orig.get_label(), m_etype.get_label())
|
||||||
|
np.testing.assert_allclose(m_etype.get_label(), y.values)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user