Support pandas nullable types. (#7760)
This commit is contained in:
parent
d4796482b5
commit
9150fdbd4d
@ -220,6 +220,11 @@ _pandas_dtype_mapper = {
|
||||
'float32': 'float',
|
||||
'float64': 'float',
|
||||
'bool': 'i',
|
||||
# nullable types
|
||||
"Int16": "int",
|
||||
"Int32": "int",
|
||||
"Int64": "int",
|
||||
"boolean": "i",
|
||||
}
|
||||
|
||||
|
||||
@ -242,6 +247,7 @@ be set to `True`.""" + err
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
# pylint: disable=too-many-locals
|
||||
def _transform_pandas_df(
|
||||
data: DataFrame,
|
||||
enable_categorical: bool,
|
||||
@ -251,11 +257,26 @@ def _transform_pandas_df(
|
||||
meta_type: Optional[str] = None,
|
||||
) -> Tuple[np.ndarray, FeatureNames, Optional[List[str]]]:
|
||||
import pandas as pd
|
||||
from pandas.api.types import is_sparse, is_categorical_dtype
|
||||
from pandas.api.types import (
|
||||
is_sparse,
|
||||
is_categorical_dtype,
|
||||
is_integer_dtype,
|
||||
is_bool_dtype,
|
||||
)
|
||||
|
||||
# Names of the pandas nullable integer extension dtypes recognised below.
nullable_alias = {"Int16", "Int32", "Int64"}


def is_nullable_dtype(dtype: Any) -> bool:
    """Return whether ``dtype`` is one of the supported pandas nullable dtypes.

    A dtype qualifies when it is an integer dtype whose name is listed in
    ``nullable_alias``, or a boolean dtype whose name is ``"boolean"``.

    Parameters
    ----------
    dtype :
        A column dtype; per the original note this is expected to be a
        ``pd.core.arrays.numeric.NumericDtype`` for the integer case.
    """
    if is_integer_dtype(dtype) and dtype.name in nullable_alias:
        return True
    # NumPy's bool dtype is named `bool`, whereas pd.BooleanDtype is named
    # `boolean`, so the name check separates the nullable variant.
    return is_bool_dtype(dtype) and dtype.name == "boolean"
|
||||
|
||||
if not all(
|
||||
dtype.name in _pandas_dtype_mapper
|
||||
or is_sparse(dtype)
|
||||
or is_nullable_dtype(dtype)
|
||||
or (is_categorical_dtype(dtype) and enable_categorical)
|
||||
for dtype in data.dtypes
|
||||
):
|
||||
@ -284,7 +305,9 @@ def _transform_pandas_df(
|
||||
# handle category codes.
|
||||
transformed = pd.DataFrame()
|
||||
# Avoid transformation due to: PerformanceWarning: DataFrame is highly fragmented
|
||||
if enable_categorical and any(is_categorical_dtype(dtype) for dtype in data.dtypes):
|
||||
if (
|
||||
enable_categorical and any(is_categorical_dtype(dtype) for dtype in data.dtypes)
|
||||
) or any(is_nullable_dtype(dtype) for dtype in data.dtypes):
|
||||
for i, dtype in enumerate(data.dtypes):
|
||||
if is_categorical_dtype(dtype):
|
||||
# pandas uses -1 as default missing value for categorical data
|
||||
@ -293,6 +316,9 @@ def _transform_pandas_df(
|
||||
.cat.codes.astype(np.float32)
|
||||
.replace(-1.0, np.NaN)
|
||||
)
|
||||
elif is_nullable_dtype(dtype):
|
||||
# Converts integer <NA> to float NaN
|
||||
transformed[data.columns[i]] = data[data.columns[i]].astype(np.float32)
|
||||
else:
|
||||
transformed[data.columns[i]] = data[data.columns[i]]
|
||||
else:
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
import tempfile
|
||||
import numpy as np
|
||||
import xgboost as xgb
|
||||
import testing as tm
|
||||
@ -293,3 +294,39 @@ class TestPandas:
|
||||
assert isinstance(params, list)
|
||||
assert 'auc' not in cv.columns[0]
|
||||
assert 'error' in cv.columns[0]
|
||||
|
||||
def test_nullable_type(self):
    """Check that pandas nullable dtypes are accepted by ``xgb.DMatrix``.

    Each frame is turned into a DMatrix, saved in binary form, and compared
    byte-for-byte: a nullable integer frame must match its float32
    equivalent, and a nullable boolean frame must differ from a plain bool
    frame.
    """
    y = np.random.default_rng(0).random(4)

    def to_bytes(Xy: xgb.DMatrix) -> bytes:
        # Round-trip through a temporary file to obtain the binary form.
        with tempfile.TemporaryDirectory() as tmpdir:
            path = os.path.join(tmpdir, "Xy.dmatrix")
            Xy.save_binary(path)
            with open(path, "rb") as fd:
                result = fd.read()
        return result

    def test_int(dtype) -> bytes:
        arr = pd.DataFrame(
            {"f0": [1, 2, None, 3], "f1": [4, 3, None, 1]}, dtype=dtype
        )
        Xy = xgb.DMatrix(arr, y)
        # Drop feature types so only the stored values are compared.
        Xy.feature_types = None
        return to_bytes(Xy)

    b0 = test_int(np.float32)
    b1 = test_int(pd.Int16Dtype())
    assert b0 == b1

    def test_bool(dtype) -> bytes:
        arr = pd.DataFrame(
            {"f0": [True, False, None, True], "f1": [False, True, None, True]},
            dtype=dtype,
        )
        Xy = xgb.DMatrix(arr, y)
        Xy.feature_types = None
        return to_bytes(Xy)

    b0 = test_bool(pd.BooleanDtype())
    # Use the builtin `bool`: the `np.bool` alias was deprecated in
    # NumPy 1.20 and removed in 1.24, where it raises AttributeError.
    b1 = test_bool(bool)
    assert b0 != b1  # None is converted to False with plain bool
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user