Fix feature_name crated from int64index dataframe. (#5081)

This commit is contained in:
K.O 2019-12-30 13:26:22 +09:00 committed by Jiaming Yuan
parent 139ccc9902
commit 018df6004e
3 changed files with 12 additions and 2 deletions

View File

@ -81,13 +81,14 @@ else:
# pandas
try:
from pandas import DataFrame, Series
from pandas import MultiIndex
from pandas import MultiIndex, Int64Index
from pandas import concat as pandas_concat
PANDAS_INSTALLED = True
except ImportError:
MultiIndex = object
Int64Index = object
DataFrame = object
Series = object
pandas_concat = None

View File

@ -18,7 +18,7 @@ import numpy as np
import scipy.sparse
from .compat import (
STRING_TYPES, DataFrame, MultiIndex, py_str,
STRING_TYPES, DataFrame, MultiIndex, Int64Index, py_str,
PANDAS_INSTALLED, DataTable,
CUDF_INSTALLED, CUDF_DataFrame, CUDF_Series, CUDF_MultiIndex,
os_fspath, os_PathLike)
@ -296,6 +296,8 @@ def _maybe_pandas_data(data, feature_names, feature_types):
' '.join([str(x) for x in i])
for i in data.columns
]
elif isinstance(data.columns, Int64Index):
feature_names = list(map(str, data.columns))
else:
feature_names = data.columns.format()

View File

@ -83,6 +83,13 @@ class TestPandas(unittest.TestCase):
assert dm.num_row() == 3
assert dm.num_col() == 2
df_int = pd.DataFrame([[1, 1.1], [2, 2.2]], columns=[9, 10])
dm_int = xgb.DMatrix(df_int)
df_range = pd.DataFrame([[1, 1.1], [2, 2.2]], columns=range(9, 11, 1))
dm_range = xgb.DMatrix(df_range)
assert dm_int.feature_names == ['9', '10'] # assert not "9 "
assert dm_int.feature_names == dm_range.feature_names
# test MultiIndex as columns
df = pd.DataFrame(
[