Support more pandas nullable types (#8262)

- Float32/64
- Category.
This commit is contained in:
Jiaming Yuan
2022-09-27 01:59:50 +08:00
committed by GitHub
parent 1082ccd3cc
commit fcab51aa82
2 changed files with 48 additions and 7 deletions

View File

@@ -231,13 +231,15 @@ _pandas_dtype_mapper = {
"Int16": "int",
"Int32": "int",
"Int64": "int",
"Float32": "float",
"Float64": "float",
"boolean": "i",
}
_ENABLE_CAT_ERR = (
"When categorical type is supplied, DMatrix parameter `enable_categorical` must "
"be set to `True`."
"When categorical type is supplied, The experimental DMatrix parameter"
"`enable_categorical` must be set to `True`."
)
@@ -246,7 +248,7 @@ def _invalid_dataframe_dtype(data: DataType) -> None:
# cudf series doesn't have `dtypes`.
if hasattr(data, "dtypes") and hasattr(data.dtypes, "__iter__"):
bad_fields = [
str(data.columns[i])
f"{data.columns[i]}: {dtype}"
for i, dtype in enumerate(data.dtypes)
if dtype.name not in _pandas_dtype_mapper
]
@@ -296,13 +298,20 @@ def _pandas_feature_info(
def is_nullable_dtype(dtype: PandasDType) -> bool:
"""Whether dtype is a pandas nullable type."""
from pandas.api.types import is_integer_dtype, is_bool_dtype
from pandas.api.types import (
is_integer_dtype,
is_bool_dtype,
is_float_dtype,
is_categorical_dtype,
)
# dtype: pd.core.arrays.numeric.NumericDtype
nullable_alias = {"Int16", "Int32", "Int64"}
nullable_alias = {"Int16", "Int32", "Int64", "Float32", "Float64", "category"}
is_int = is_integer_dtype(dtype) and dtype.name in nullable_alias
# np.bool has alias `bool`, while pd.BooleanDtype has `boolean`.
is_bool = is_bool_dtype(dtype) and dtype.name == "boolean"
return is_int or is_bool
is_float = is_float_dtype(dtype) and dtype.name in nullable_alias
return is_int or is_bool or is_float or is_categorical_dtype(dtype)
def _pandas_cat_null(data: DataFrame) -> DataFrame:
@@ -353,7 +362,7 @@ def _transform_pandas_df(
if not all(
dtype.name in _pandas_dtype_mapper
or is_sparse(dtype)
or is_nullable_dtype(dtype)
or (is_nullable_dtype(dtype) and not is_categorical_dtype(dtype))
or (is_categorical_dtype(dtype) and enable_categorical)
for dtype in data.dtypes
):