Fix issue 2670 (#2671)
* fix issue 2670 * add python<3.6 compatibility * fix Index * fix Index/MultiIndex * fix lint * fix W0622 really nonsense * fix lambda * Trigger Travis * add test for MultiIndex * remove tailing whitespace
This commit is contained in:
parent
ee80f348de
commit
0e85b30fdd
@ -184,6 +184,12 @@ Did not expect the data types in fields """
|
||||
raise ValueError(msg + ', '.join(bad_fields))
|
||||
|
||||
if feature_names is None:
|
||||
if hasattr(data.columns, 'to_frame'): # MultiIndex
|
||||
feature_names = [
|
||||
' '.join(map(str, i))
|
||||
for i in data.columns
|
||||
]
|
||||
else:
|
||||
feature_names = data.columns.format()
|
||||
|
||||
if feature_types is None:
|
||||
|
||||
@ -79,6 +79,23 @@ class TestPandas(unittest.TestCase):
|
||||
assert dm.num_row() == 3
|
||||
assert dm.num_col() == 2
|
||||
|
||||
# test MultiIndex as columns
|
||||
df = pd.DataFrame(
|
||||
[
|
||||
(1, 2, 3, 4, 5, 6),
|
||||
(6, 5, 4, 3, 2, 1)
|
||||
],
|
||||
columns=pd.MultiIndex.from_tuples((
|
||||
('a', 1), ('a', 2), ('a', 3),
|
||||
('b', 1), ('b', 2), ('b', 3),
|
||||
))
|
||||
)
|
||||
dm = xgb.DMatrix(df)
|
||||
assert dm.feature_names == ['a 1', 'a 2', 'a 3', 'b 1', 'b 2', 'b 3']
|
||||
assert dm.feature_types == ['int', 'int', 'int', 'int', 'int', 'int']
|
||||
assert dm.num_row() == 2
|
||||
assert dm.num_col() == 6
|
||||
|
||||
def test_pandas_label(self):
|
||||
# label must be a single column
|
||||
df = pd.DataFrame({'A': ['X', 'Y', 'Z'], 'B': [1, 2, 3]})
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user