Fix feature_name crated from int64index dataframe. (#5081)
This commit is contained in:
parent
139ccc9902
commit
018df6004e
@ -81,13 +81,14 @@ else:
|
||||
# pandas
|
||||
try:
|
||||
from pandas import DataFrame, Series
|
||||
from pandas import MultiIndex
|
||||
from pandas import MultiIndex, Int64Index
|
||||
from pandas import concat as pandas_concat
|
||||
|
||||
PANDAS_INSTALLED = True
|
||||
except ImportError:
|
||||
|
||||
MultiIndex = object
|
||||
Int64Index = object
|
||||
DataFrame = object
|
||||
Series = object
|
||||
pandas_concat = None
|
||||
|
||||
@ -18,7 +18,7 @@ import numpy as np
|
||||
import scipy.sparse
|
||||
|
||||
from .compat import (
|
||||
STRING_TYPES, DataFrame, MultiIndex, py_str,
|
||||
STRING_TYPES, DataFrame, MultiIndex, Int64Index, py_str,
|
||||
PANDAS_INSTALLED, DataTable,
|
||||
CUDF_INSTALLED, CUDF_DataFrame, CUDF_Series, CUDF_MultiIndex,
|
||||
os_fspath, os_PathLike)
|
||||
@ -296,6 +296,8 @@ def _maybe_pandas_data(data, feature_names, feature_types):
|
||||
' '.join([str(x) for x in i])
|
||||
for i in data.columns
|
||||
]
|
||||
elif isinstance(data.columns, Int64Index):
|
||||
feature_names = list(map(str, data.columns))
|
||||
else:
|
||||
feature_names = data.columns.format()
|
||||
|
||||
|
||||
@ -83,6 +83,13 @@ class TestPandas(unittest.TestCase):
|
||||
assert dm.num_row() == 3
|
||||
assert dm.num_col() == 2
|
||||
|
||||
df_int = pd.DataFrame([[1, 1.1], [2, 2.2]], columns=[9, 10])
|
||||
dm_int = xgb.DMatrix(df_int)
|
||||
df_range = pd.DataFrame([[1, 1.1], [2, 2.2]], columns=range(9, 11, 1))
|
||||
dm_range = xgb.DMatrix(df_range)
|
||||
assert dm_int.feature_names == ['9', '10'] # assert not "9 "
|
||||
assert dm_int.feature_names == dm_range.feature_names
|
||||
|
||||
# test MultiIndex as columns
|
||||
df = pd.DataFrame(
|
||||
[
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user