Fix issue 2670 (#2671)
* fix issue 2670 * add python<3.6 compatibility * fix Index * fix Index/MultiIndex * fix lint * fix W0622 really nonsense * fix lambda * Trigger Travis * add test for MultiIndex * remove trailing whitespace
This commit is contained in:
parent
ee80f348de
commit
0e85b30fdd
@@ -184,7 +184,13 @@ Did not expect the data types in fields """
|
|||||||
raise ValueError(msg + ', '.join(bad_fields))
|
raise ValueError(msg + ', '.join(bad_fields))
|
||||||
|
|
||||||
if feature_names is None:
|
if feature_names is None:
|
||||||
feature_names = data.columns.format()
|
if hasattr(data.columns, 'to_frame'): # MultiIndex
|
||||||
|
feature_names = [
|
||||||
|
' '.join(map(str, i))
|
||||||
|
for i in data.columns
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
feature_names = data.columns.format()
|
||||||
|
|
||||||
if feature_types is None:
|
if feature_types is None:
|
||||||
feature_types = [PANDAS_DTYPE_MAPPER[dtype.name] for dtype in data_dtypes]
|
feature_types = [PANDAS_DTYPE_MAPPER[dtype.name] for dtype in data_dtypes]
|
||||||
|
|||||||
@@ -79,6 +79,23 @@ class TestPandas(unittest.TestCase):
|
|||||||
assert dm.num_row() == 3
|
assert dm.num_row() == 3
|
||||||
assert dm.num_col() == 2
|
assert dm.num_col() == 2
|
||||||
|
|
||||||
|
# test MultiIndex as columns
|
||||||
|
df = pd.DataFrame(
|
||||||
|
[
|
||||||
|
(1, 2, 3, 4, 5, 6),
|
||||||
|
(6, 5, 4, 3, 2, 1)
|
||||||
|
],
|
||||||
|
columns=pd.MultiIndex.from_tuples((
|
||||||
|
('a', 1), ('a', 2), ('a', 3),
|
||||||
|
('b', 1), ('b', 2), ('b', 3),
|
||||||
|
))
|
||||||
|
)
|
||||||
|
dm = xgb.DMatrix(df)
|
||||||
|
assert dm.feature_names == ['a 1', 'a 2', 'a 3', 'b 1', 'b 2', 'b 3']
|
||||||
|
assert dm.feature_types == ['int', 'int', 'int', 'int', 'int', 'int']
|
||||||
|
assert dm.num_row() == 2
|
||||||
|
assert dm.num_col() == 6
|
||||||
|
|
||||||
def test_pandas_label(self):
|
def test_pandas_label(self):
|
||||||
# label must be a single column
|
# label must be a single column
|
||||||
df = pd.DataFrame({'A': ['X', 'Y', 'Z'], 'B': [1, 2, 3]})
|
df = pd.DataFrame({'A': ['X', 'Y', 'Z'], 'B': [1, 2, 3]})
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user