From 24789429fdf4e1e982f0af6f2f2225382c42b756 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 26 Jan 2022 18:20:10 +0800 Subject: [PATCH] Support latest pandas Index type. (#7595) --- python-package/xgboost/compat.py | 3 +-- python-package/xgboost/data.py | 2 +- tests/python/test_with_pandas.py | 6 ++++++ 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/python-package/xgboost/compat.py b/python-package/xgboost/compat.py index 954d04aea..256a77adf 100644 --- a/python-package/xgboost/compat.py +++ b/python-package/xgboost/compat.py @@ -33,14 +33,13 @@ def lazy_isinstance(instance, module, name): # pandas try: from pandas import DataFrame, Series - from pandas import MultiIndex, Int64Index + from pandas import MultiIndex from pandas import concat as pandas_concat PANDAS_INSTALLED = True except ImportError: MultiIndex = object - Int64Index = object DataFrame: Any = object Series = object pandas_concat = None diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index bfc4cf2d8..f378e0e26 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -264,7 +264,7 @@ def _transform_pandas_df( if feature_names is None and meta is None: if isinstance(data.columns, pd.MultiIndex): feature_names = [" ".join([str(x) for x in i]) for i in data.columns] - elif isinstance(data.columns, (pd.Int64Index, pd.RangeIndex)): + elif isinstance(data.columns, (pd.Index, pd.RangeIndex)): feature_names = list(map(str, data.columns)) else: feature_names = data.columns.format() diff --git a/tests/python/test_with_pandas.py b/tests/python/test_with_pandas.py index 888784581..b8191efcb 100644 --- a/tests/python/test_with_pandas.py +++ b/tests/python/test_with_pandas.py @@ -109,6 +109,12 @@ class TestPandas: assert dm.num_row() == 2 assert dm.num_col() == 6 + # test Index as columns + df = pd.DataFrame([[1, 1.1], [2, 2.2]], columns=pd.Index([1, 2])) + print(df.columns, isinstance(df.columns, pd.Index)) + Xy = xgb.DMatrix(df) + np.testing.assert_equal(np.array(Xy.feature_names), np.array(["1", "2"])) + def test_slice(self): rng = np.random.RandomState(1994) rows = 100