From 1325ba92517cc9edbda8400f01e3e844a2ec237b Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 30 Jan 2023 17:53:29 +0800 Subject: [PATCH] Support primitive types of pyarrow-backed pandas dataframe. (#8653) Categorical data (dictionary) is not supported at the moment. --- python-package/xgboost/data.py | 113 +++++++++++++++++++++---- python-package/xgboost/testing/data.py | 54 ++++++++++++ tests/python/test_with_pandas.py | 41 ++++++++- 3 files changed, 189 insertions(+), 19 deletions(-) diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 68650df6f..db7fdd960 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -251,7 +251,25 @@ pandas_nullable_mapper = { "boolean": "i", } +pandas_pyarrow_mapper = { + "int8[pyarrow]": "i", + "int16[pyarrow]": "i", + "int32[pyarrow]": "i", + "int64[pyarrow]": "i", + "uint8[pyarrow]": "i", + "uint16[pyarrow]": "i", + "uint32[pyarrow]": "i", + "uint64[pyarrow]": "i", + "float[pyarrow]": "float", + "float32[pyarrow]": "float", + "double[pyarrow]": "float", + "float64[pyarrow]": "float", + "bool[pyarrow]": "i", +} + _pandas_dtype_mapper.update(pandas_nullable_mapper) +_pandas_dtype_mapper.update(pandas_pyarrow_mapper) + _ENABLE_CAT_ERR = ( "When categorical type is supplied, The experimental DMatrix parameter" @@ -277,13 +295,14 @@ def _invalid_dataframe_dtype(data: DataType) -> None: raise ValueError(msg) -def _pandas_feature_info( +def pandas_feature_info( data: DataFrame, meta: Optional[str], feature_names: Optional[FeatureNames], feature_types: Optional[FeatureTypes], enable_categorical: bool, ) -> Tuple[Optional[FeatureNames], Optional[FeatureTypes]]: + """Handle feature info for pandas dataframe.""" import pandas as pd from pandas.api.types import is_categorical_dtype, is_sparse @@ -302,7 +321,9 @@ def _pandas_feature_info( for dtype in data.dtypes: if is_sparse(dtype): feature_types.append(_pandas_dtype_mapper[dtype.subtype.name]) - elif is_categorical_dtype(dtype) and enable_categorical: + elif ( + is_categorical_dtype(dtype) or is_pa_ext_categorical_dtype(dtype) + ) and enable_categorical: feature_types.append(CAT_T) else: feature_types.append(_pandas_dtype_mapper[dtype.name]) @@ -310,7 +331,7 @@ def _pandas_feature_info( def is_nullable_dtype(dtype: PandasDType) -> bool: - """Wether dtype is a pandas nullable type.""" + """Whether dtype is a pandas nullable type.""" from pandas.api.types import ( is_bool_dtype, is_categorical_dtype, @@ -319,38 +340,63 @@ def is_nullable_dtype(dtype: PandasDType) -> bool: ) is_int = is_integer_dtype(dtype) and dtype.name in pandas_nullable_mapper - # np.bool has alias `bool`, while pd.BooleanDtype has `bzoolean`. + # np.bool has alias `bool`, while pd.BooleanDtype has `boolean`. is_bool = is_bool_dtype(dtype) and dtype.name == "boolean" is_float = is_float_dtype(dtype) and dtype.name in pandas_nullable_mapper return is_int or is_bool or is_float or is_categorical_dtype(dtype) +def is_pa_ext_dtype(dtype: Any) -> bool: + """Return whether dtype is a pyarrow extension type for pandas""" + return hasattr(dtype, "pyarrow_dtype") + + +def is_pa_ext_categorical_dtype(dtype: Any) -> bool: + """Check whether dtype is a dictionary type.""" + return lazy_isinstance( + getattr(dtype, "pyarrow_dtype", None), "pyarrow.lib", "DictionaryType" + ) + + def pandas_cat_null(data: DataFrame) -> DataFrame: """Handle categorical dtype and nullable extension types from pandas.""" + import pandas as pd from pandas.api.types import is_categorical_dtype # handle category codes and nullable. cat_columns = [] nul_columns = [] + # avoid an unnecessary conversion if possible for col, dtype in zip(data.columns, data.dtypes): if is_categorical_dtype(dtype): cat_columns.append(col) - # avoid an unnecessary conversion if possible + elif is_pa_ext_categorical_dtype(dtype): + raise ValueError( + "pyarrow dictionary type is not supported. Use pandas category instead." + ) elif is_nullable_dtype(dtype): nul_columns.append(col) if cat_columns or nul_columns: # Avoid transformation due to: PerformanceWarning: DataFrame is highly # fragmented - transformed = data.copy() + transformed = data.copy(deep=False) else: transformed = data + def cat_codes(ser: pd.Series) -> pd.Series: + if is_categorical_dtype(ser.dtype): + return ser.cat.codes + assert is_pa_ext_categorical_dtype(ser.dtype) + # Not yet supported, the index is not ordered for some reason. Alternately: + # `combine_chunks().to_pandas().cat.codes`. The result is the same. + return ser.array.__arrow_array__().combine_chunks().dictionary_encode().indices + if cat_columns: # DF doesn't have the cat attribute, as a result, we use apply here transformed[cat_columns] = ( transformed[cat_columns] - .apply(lambda x: x.cat.codes) + .apply(cat_codes) .astype(np.float32) .replace(-1.0, np.NaN) ) @@ -364,6 +410,29 @@ def pandas_cat_null(data: DataFrame) -> DataFrame: return transformed +def pandas_ext_num_types(data: DataFrame) -> DataFrame: + """Experimental suppport for handling pyarrow extension numeric types.""" + import pandas as pd + import pyarrow as pa + + for col, dtype in zip(data.columns, data.dtypes): + if not is_pa_ext_dtype(dtype): + continue + # No copy, callstack: + # pandas.core.internals.managers.SingleBlockManager.array_values() + # pandas.core.internals.blocks.EABackedBlock.values + d_array: pd.arrays.ArrowExtensionArray = data[col].array + # no copy in __arrow_array__ + # ArrowExtensionArray._data is a chunked array + aa: pa.ChunkedArray = d_array.__arrow_array__() + chunk: pa.Array = aa.combine_chunks() + # Alternately, we can use chunk.buffers(), which returns a list of buffers and + # we need to concatenate them ourselves. + arr = chunk.__array__() + data[col] = arr + return data + + def _transform_pandas_df( data: DataFrame, enable_categorical: bool, @@ -374,19 +443,27 @@ def _transform_pandas_df( ) -> Tuple[np.ndarray, Optional[FeatureNames], Optional[FeatureTypes]]: from pandas.api.types import is_categorical_dtype, is_sparse - if not all( - (dtype.name in _pandas_dtype_mapper) - or is_sparse(dtype) - or (is_categorical_dtype(dtype) and enable_categorical) - for dtype in data.dtypes - ): - _invalid_dataframe_dtype(data) + pyarrow_extension = False + for dtype in data.dtypes: + if not ( + (dtype.name in _pandas_dtype_mapper) + or is_sparse(dtype) + or (is_categorical_dtype(dtype) and enable_categorical) + or is_pa_ext_dtype(dtype) + ): + _invalid_dataframe_dtype(data) + if is_pa_ext_dtype(dtype): + pyarrow_extension = True - feature_names, feature_types = _pandas_feature_info( + feature_names, feature_types = pandas_feature_info( data, meta, feature_names, feature_types, enable_categorical ) transformed = pandas_cat_null(data) + if pyarrow_extension: + if transformed is data: + transformed = data.copy(deep=False) + transformed = pandas_ext_num_types(transformed) if meta and len(data.columns) > 1 and meta not in _matrix_meta: raise ValueError(f"DataFrame for {meta} cannot have multiple columns") @@ -1192,7 +1269,10 @@ def _proxy_transform( enable_categorical: bool, ) -> Tuple[ Union[bool, ctypes.c_void_p, np.ndarray], - Optional[list], Optional[FeatureNames], Optional[FeatureTypes]]: + Optional[list], + Optional[FeatureNames], + Optional[FeatureTypes], +]: if _is_cudf_df(data) or _is_cudf_ser(data): return _transform_cudf_df( data, feature_names, feature_types, enable_categorical @@ -1212,6 +1292,7 @@ def _proxy_transform( return data, None, feature_names, feature_types if _is_pandas_series(data): import pandas as pd + data = pd.DataFrame(data) if _is_pandas_df(data): arr, feature_names, feature_types = _transform_pandas_df( diff --git a/python-package/xgboost/testing/data.py b/python-package/xgboost/testing/data.py index b6f47ce5d..791ffd7ec 100644 --- a/python-package/xgboost/testing/data.py +++ b/python-package/xgboost/testing/data.py @@ -2,6 +2,7 @@ from typing import Any, Generator, Tuple, Union import numpy as np +from xgboost.data import pandas_pyarrow_mapper def np_dtypes( @@ -124,3 +125,56 @@ def pd_dtypes() -> Generator: orig = pd.DataFrame(data, dtype=np.bool_ if Null is None else pd.BooleanDtype()) df = pd.DataFrame(data, dtype=pd.BooleanDtype()) yield orig, df + + +def pd_arrow_dtypes() -> Generator: + """Pandas DataFrame with pyarrow backed type.""" + import pandas as pd + import pyarrow as pa # pylint: disable=import-error + + # Integer + dtypes = pandas_pyarrow_mapper + Null: Union[float, None, Any] = np.nan + orig = pd.DataFrame( + {"f0": [1, 2, Null, 3], "f1": [4, 3, Null, 1]}, dtype=np.float32 + ) + # Create a dictionary-backed dataframe, enable this when the roundtrip is + # implemented in pandas/pyarrow + # + # category = pd.ArrowDtype(pa.dictionary(pa.int32(), pa.int32(), ordered=True)) + # df = pd.DataFrame({"f0": [0, 2, Null, 3], "f1": [4, 3, Null, 1]}, dtype=category) + + # Error: + # >>> df.astype("category") + # Function 'dictionary_encode' has no kernel matching input types + # (array[dictionary]) + + # Error: + # pd_cat_df = pd.DataFrame( + # {"f0": [0, 2, Null, 3], "f1": [4, 3, Null, 1]}, + # dtype="category" + # ) + # pa_catcodes = ( + # df["f1"].array.__arrow_array__().combine_chunks().to_pandas().cat.codes + # ) + # pd_catcodes = pd_cat_df["f1"].cat.codes + # assert pd_catcodes.equals(pa_catcodes) + + for Null in (None, pd.NA): + for dtype in dtypes: + if dtype.startswith("float16") or dtype.startswith("bool"): + continue + df = pd.DataFrame( + {"f0": [1, 2, Null, 3], "f1": [4, 3, Null, 1]}, dtype=dtype + ) + yield orig, df + + orig = pd.DataFrame( + {"f0": [True, False, pd.NA, True], "f1": [False, True, pd.NA, True]}, + dtype=pd.BooleanDtype(), + ) + df = pd.DataFrame( + {"f0": [True, False, pd.NA, True], "f1": [False, True, pd.NA, True]}, + dtype=pd.ArrowDtype(pa.bool_()), + ) + yield orig, df diff --git a/tests/python/test_with_pandas.py b/tests/python/test_with_pandas.py index ff2c2e6eb..99b34c336 100644 --- a/tests/python/test_with_pandas.py +++ b/tests/python/test_with_pandas.py @@ -1,7 +1,9 @@ +from typing import Type + import numpy as np import pytest from test_dmatrix import set_base_margin_info -from xgboost.testing.data import pd_dtypes +from xgboost.testing.data import pd_arrow_dtypes, pd_dtypes import xgboost as xgb from xgboost import testing as tm @@ -305,9 +307,17 @@ class TestPandas: # series enable_categorical = is_categorical(df.dtype) - m_orig = DMatrixT(orig, enable_categorical=enable_categorical) + f0_orig = orig[orig.columns[0]] if isinstance(orig, pd.DataFrame) else orig + f0 = df[df.columns[0]] if isinstance(df, pd.DataFrame) else df + y_orig = f0_orig.astype(pd.Float32Dtype()).fillna(0) + y = f0.astype(pd.Float32Dtype()).fillna(0) + + m_orig = DMatrixT(orig, enable_categorical=enable_categorical, label=y_orig) # extension types - m_etype = DMatrixT(df, enable_categorical=enable_categorical) + copy = df.copy() + m_etype = DMatrixT(df, enable_categorical=enable_categorical, label=y) + # no mutation + assert df.equals(copy) # different from pd.BooleanDtype(), None is converted to False with bool if hasattr(orig.dtypes, "__iter__") and any( dtype == "bool" for dtype in orig.dtypes @@ -316,7 +326,32 @@ class TestPandas: else: assert tm.predictor_equal(m_orig, m_etype) + np.testing.assert_allclose(m_orig.get_label(), m_etype.get_label()) + np.testing.assert_allclose(m_etype.get_label(), y.values.astype(np.float32)) + if isinstance(df, pd.DataFrame): f0 = df["f0"] with pytest.raises(ValueError, match="Label contains NaN"): xgb.DMatrix(df, f0, enable_categorical=enable_categorical) + + @pytest.mark.skipif(**tm.no_arrow()) + @pytest.mark.parametrize("DMatrixT", [xgb.DMatrix, xgb.QuantileDMatrix]) + def test_pyarrow_type(self, DMatrixT: Type[xgb.DMatrix]) -> None: + for orig, df in pd_arrow_dtypes(): + f0_orig: pd.Series = orig["f0"] + f0 = df["f0"] + + if f0.dtype.name.startswith("bool"): + y = None + y_orig = None + else: + y_orig = f0_orig.fillna(0, inplace=False) + y = f0.fillna(0, inplace=False) + + m_orig = DMatrixT(orig, enable_categorical=True, label=y_orig) + m_etype = DMatrixT(df, enable_categorical=True, label=y) + + assert tm.predictor_equal(m_orig, m_etype) + if y is not None: + np.testing.assert_allclose(m_orig.get_label(), m_etype.get_label()) + np.testing.assert_allclose(m_etype.get_label(), y.values)