Support more pandas nullable types (#8262)

- Float32/64
- Category.
This commit is contained in:
Jiaming Yuan 2022-09-27 01:59:50 +08:00 committed by GitHub
parent 1082ccd3cc
commit fcab51aa82
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 48 additions and 7 deletions

View File

@ -231,13 +231,15 @@ _pandas_dtype_mapper = {
"Int16": "int",
"Int32": "int",
"Int64": "int",
"Float32": "float",
"Float64": "float",
"boolean": "i",
}
_ENABLE_CAT_ERR = (
"When categorical type is supplied, DMatrix parameter `enable_categorical` must "
"be set to `True`."
"When categorical type is supplied, The experimental DMatrix parameter"
"`enable_categorical` must be set to `True`."
)
@ -246,7 +248,7 @@ def _invalid_dataframe_dtype(data: DataType) -> None:
# cudf series doesn't have `dtypes`.
if hasattr(data, "dtypes") and hasattr(data.dtypes, "__iter__"):
bad_fields = [
str(data.columns[i])
f"{data.columns[i]}: {dtype}"
for i, dtype in enumerate(data.dtypes)
if dtype.name not in _pandas_dtype_mapper
]
@ -296,13 +298,20 @@ def _pandas_feature_info(
def is_nullable_dtype(dtype: PandasDType) -> bool:
"""Wether dtype is a pandas nullable type."""
from pandas.api.types import is_integer_dtype, is_bool_dtype
from pandas.api.types import (
is_integer_dtype,
is_bool_dtype,
is_float_dtype,
is_categorical_dtype,
)
# dtype: pd.core.arrays.numeric.NumericDtype
nullable_alias = {"Int16", "Int32", "Int64"}
nullable_alias = {"Int16", "Int32", "Int64", "Float32", "Float64", "category"}
is_int = is_integer_dtype(dtype) and dtype.name in nullable_alias
# np.bool has alias `bool`, while pd.BooleanDtype has `bzoolean`.
is_bool = is_bool_dtype(dtype) and dtype.name == "boolean"
return is_int or is_bool
is_float = is_float_dtype(dtype) and dtype.name in nullable_alias
return is_int or is_bool or is_float or is_categorical_dtype(dtype)
def _pandas_cat_null(data: DataFrame) -> DataFrame:
@ -353,7 +362,7 @@ def _transform_pandas_df(
if not all(
dtype.name in _pandas_dtype_mapper
or is_sparse(dtype)
or is_nullable_dtype(dtype)
or (is_nullable_dtype(dtype) and not is_categorical_dtype(dtype))
or (is_categorical_dtype(dtype) and enable_categorical)
for dtype in data.dtypes
):

View File

@ -330,3 +330,35 @@ class TestPandas:
b0 = test_bool(pd.BooleanDtype())
b1 = test_bool(bool)
assert b0 != b1 # None is converted to False with np.bool
data = {"f0": [1.0, 2.0, None, 3.0], "f1": [3.0, 2.0, None, 1.0]}
arr = np.array([data["f0"], data["f1"]]).T
Xy = xgb.DMatrix(arr, y)
Xy.feature_types = None
Xy.feature_names = None
from_np = to_bytes(Xy)
def test_float(dtype) -> bytes:
arr = pd.DataFrame(data, dtype=dtype)
Xy = xgb.DMatrix(arr, y)
Xy.feature_types = None
Xy.feature_names = None
return to_bytes(Xy)
b0 = test_float(pd.Float64Dtype())
b1 = test_float(float)
assert b0 == b1 # both are converted to NaN
assert b0 == from_np
def test_cat(dtype) -> bytes:
arr = pd.DataFrame(data, dtype=dtype)
if dtype is None:
arr = arr.astype("category")
Xy = xgb.DMatrix(arr, y, enable_categorical=True)
Xy.feature_types = None
return to_bytes(Xy)
b0 = test_cat(pd.CategoricalDtype())
b1 = test_cat(None)
assert b0 == b1