Fix issue 2670 (#2671)

* fix issue 2670

* add python<3.6 compatibility

* fix Index

* fix Index/MultiIndex

* fix lint

* fix W0622

really nonsense

* fix lambda

* Trigger Travis

* add test for MultiIndex

* remove trailing whitespace
This commit is contained in:
Icyblade Dai 2017-09-20 03:49:41 +08:00 committed by Yuan (Terry) Tang
parent ee80f348de
commit 0e85b30fdd
2 changed files with 24 additions and 1 deletions

View File

@ -184,6 +184,12 @@ Did not expect the data types in fields """
raise ValueError(msg + ', '.join(bad_fields)) raise ValueError(msg + ', '.join(bad_fields))
if feature_names is None: if feature_names is None:
if hasattr(data.columns, 'to_frame'): # MultiIndex
feature_names = [
' '.join(map(str, i))
for i in data.columns
]
else:
feature_names = data.columns.format() feature_names = data.columns.format()
if feature_types is None: if feature_types is None:

View File

@ -79,6 +79,23 @@ class TestPandas(unittest.TestCase):
assert dm.num_row() == 3 assert dm.num_row() == 3
assert dm.num_col() == 2 assert dm.num_col() == 2
# test MultiIndex as columns
df = pd.DataFrame(
[
(1, 2, 3, 4, 5, 6),
(6, 5, 4, 3, 2, 1)
],
columns=pd.MultiIndex.from_tuples((
('a', 1), ('a', 2), ('a', 3),
('b', 1), ('b', 2), ('b', 3),
))
)
dm = xgb.DMatrix(df)
assert dm.feature_names == ['a 1', 'a 2', 'a 3', 'b 1', 'b 2', 'b 3']
assert dm.feature_types == ['int', 'int', 'int', 'int', 'int', 'int']
assert dm.num_row() == 2
assert dm.num_col() == 6
def test_pandas_label(self): def test_pandas_label(self):
# label must be a single column # label must be a single column
df = pd.DataFrame({'A': ['X', 'Y', 'Z'], 'B': [1, 2, 3]}) df = pd.DataFrame({'A': ['X', 'Y', 'Z'], 'B': [1, 2, 3]})