From 0e85b30fdd8d99a39171ec1e7718c4da7ad30a4d Mon Sep 17 00:00:00 2001 From: Icyblade Dai Date: Wed, 20 Sep 2017 03:49:41 +0800 Subject: [PATCH] Fix issue 2670 (#2671) * fix issue 2670 * add python<3.6 compatibility * fix Index * fix Index/MultiIndex * fix lint * fix W0622 really nonsense * fix lambda * Trigger Travis * add test for MultiIndex * remove tailing whitespace --- python-package/xgboost/core.py | 8 +++++++- tests/python/test_with_pandas.py | 17 +++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 3165a2359..8ae45c5bc 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -184,7 +184,13 @@ Did not expect the data types in fields """ raise ValueError(msg + ', '.join(bad_fields)) if feature_names is None: - feature_names = data.columns.format() + if hasattr(data.columns, 'to_frame'): # MultiIndex + feature_names = [ + ' '.join(map(str, i)) + for i in data.columns + ] + else: + feature_names = data.columns.format() if feature_types is None: feature_types = [PANDAS_DTYPE_MAPPER[dtype.name] for dtype in data_dtypes] diff --git a/tests/python/test_with_pandas.py b/tests/python/test_with_pandas.py index 0cb3045db..3bb26c12d 100644 --- a/tests/python/test_with_pandas.py +++ b/tests/python/test_with_pandas.py @@ -79,6 +79,23 @@ class TestPandas(unittest.TestCase): assert dm.num_row() == 3 assert dm.num_col() == 2 + # test MultiIndex as columns + df = pd.DataFrame( + [ + (1, 2, 3, 4, 5, 6), + (6, 5, 4, 3, 2, 1) + ], + columns=pd.MultiIndex.from_tuples(( + ('a', 1), ('a', 2), ('a', 3), + ('b', 1), ('b', 2), ('b', 3), + )) + ) + dm = xgb.DMatrix(df) + assert dm.feature_names == ['a 1', 'a 2', 'a 3', 'b 1', 'b 2', 'b 3'] + assert dm.feature_types == ['int', 'int', 'int', 'int', 'int', 'int'] + assert dm.num_row() == 2 + assert dm.num_col() == 6 + def test_pandas_label(self): # label must be a single column df = pd.DataFrame({'A': ['X', 'Y', 'Z'], 'B': [1, 2, 3]})