Support pandas SparseArray. (#5431)
This commit is contained in:
parent
3cf665d3ec
commit
abca9908ba
@ -246,12 +246,13 @@ def _has_cuda_array_interface(data):
|
|||||||
def _maybe_pandas_data(data, feature_names, feature_types,
|
def _maybe_pandas_data(data, feature_names, feature_types,
|
||||||
meta=None, meta_type=None):
|
meta=None, meta_type=None):
|
||||||
"""Extract internal data from pd.DataFrame for DMatrix data"""
|
"""Extract internal data from pd.DataFrame for DMatrix data"""
|
||||||
|
|
||||||
if not (PANDAS_INSTALLED and isinstance(data, DataFrame)):
|
if not (PANDAS_INSTALLED and isinstance(data, DataFrame)):
|
||||||
return data, feature_names, feature_types
|
return data, feature_names, feature_types
|
||||||
|
from pandas.api.types import is_sparse
|
||||||
|
|
||||||
data_dtypes = data.dtypes
|
data_dtypes = data.dtypes
|
||||||
if not all(dtype.name in PANDAS_DTYPE_MAPPER for dtype in data_dtypes):
|
if not all(dtype.name in PANDAS_DTYPE_MAPPER or is_sparse(dtype)
|
||||||
|
for dtype in data_dtypes):
|
||||||
bad_fields = [
|
bad_fields = [
|
||||||
str(data.columns[i]) for i, dtype in enumerate(data_dtypes)
|
str(data.columns[i]) for i, dtype in enumerate(data_dtypes)
|
||||||
if dtype.name not in PANDAS_DTYPE_MAPPER
|
if dtype.name not in PANDAS_DTYPE_MAPPER
|
||||||
@ -272,9 +273,12 @@ def _maybe_pandas_data(data, feature_names, feature_types,
|
|||||||
feature_names = data.columns.format()
|
feature_names = data.columns.format()
|
||||||
|
|
||||||
if feature_types is None and meta is None:
|
if feature_types is None and meta is None:
|
||||||
feature_types = [
|
feature_types = []
|
||||||
PANDAS_DTYPE_MAPPER[dtype.name] for dtype in data_dtypes
|
for dtype in data_dtypes:
|
||||||
]
|
if is_sparse(dtype):
|
||||||
|
feature_types.append(PANDAS_DTYPE_MAPPER[dtype.subtype.name])
|
||||||
|
else:
|
||||||
|
feature_types.append(PANDAS_DTYPE_MAPPER[dtype.name])
|
||||||
|
|
||||||
if meta and len(data.columns) > 1:
|
if meta and len(data.columns) > 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@ -109,6 +109,22 @@ class TestPandas(unittest.TestCase):
|
|||||||
assert dm.num_row() == 2
|
assert dm.num_row() == 2
|
||||||
assert dm.num_col() == 6
|
assert dm.num_col() == 6
|
||||||
|
|
||||||
|
def test_pandas_sparse(self):
|
||||||
|
import pandas as pd
|
||||||
|
rows = 100
|
||||||
|
X = pd.DataFrame(
|
||||||
|
{"A": pd.SparseArray(np.random.randint(0, 10, size=rows)),
|
||||||
|
"B": pd.SparseArray(np.random.randn(rows)),
|
||||||
|
"C": pd.SparseArray(np.random.permutation(
|
||||||
|
[True, False] * (rows // 2)))}
|
||||||
|
)
|
||||||
|
y = pd.Series(pd.SparseArray(np.random.randn(rows)))
|
||||||
|
dtrain = xgb.DMatrix(X, y)
|
||||||
|
booster = xgb.train({}, dtrain, num_boost_round=4)
|
||||||
|
predt_sparse = booster.predict(xgb.DMatrix(X))
|
||||||
|
predt_dense = booster.predict(xgb.DMatrix(X.sparse.to_dense()))
|
||||||
|
np.testing.assert_allclose(predt_sparse, predt_dense)
|
||||||
|
|
||||||
def test_pandas_label(self):
|
def test_pandas_label(self):
|
||||||
# label must be a single column
|
# label must be a single column
|
||||||
df = pd.DataFrame({'A': ['X', 'Y', 'Z'], 'B': [1, 2, 3]})
|
df = pd.DataFrame({'A': ['X', 'Y', 'Z'], 'B': [1, 2, 3]})
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user