Fix feature_names created from a DataFrame with Int64Index columns. (#5081)
This commit is contained in:
parent
139ccc9902
commit
018df6004e
@ -81,13 +81,14 @@ else:
|
|||||||
# pandas
|
# pandas
|
||||||
try:
|
try:
|
||||||
from pandas import DataFrame, Series
|
from pandas import DataFrame, Series
|
||||||
from pandas import MultiIndex
|
from pandas import MultiIndex, Int64Index
|
||||||
from pandas import concat as pandas_concat
|
from pandas import concat as pandas_concat
|
||||||
|
|
||||||
PANDAS_INSTALLED = True
|
PANDAS_INSTALLED = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|
||||||
MultiIndex = object
|
MultiIndex = object
|
||||||
|
Int64Index = object
|
||||||
DataFrame = object
|
DataFrame = object
|
||||||
Series = object
|
Series = object
|
||||||
pandas_concat = None
|
pandas_concat = None
|
||||||
|
|||||||
@ -18,7 +18,7 @@ import numpy as np
|
|||||||
import scipy.sparse
|
import scipy.sparse
|
||||||
|
|
||||||
from .compat import (
|
from .compat import (
|
||||||
STRING_TYPES, DataFrame, MultiIndex, py_str,
|
STRING_TYPES, DataFrame, MultiIndex, Int64Index, py_str,
|
||||||
PANDAS_INSTALLED, DataTable,
|
PANDAS_INSTALLED, DataTable,
|
||||||
CUDF_INSTALLED, CUDF_DataFrame, CUDF_Series, CUDF_MultiIndex,
|
CUDF_INSTALLED, CUDF_DataFrame, CUDF_Series, CUDF_MultiIndex,
|
||||||
os_fspath, os_PathLike)
|
os_fspath, os_PathLike)
|
||||||
@ -296,6 +296,8 @@ def _maybe_pandas_data(data, feature_names, feature_types):
|
|||||||
' '.join([str(x) for x in i])
|
' '.join([str(x) for x in i])
|
||||||
for i in data.columns
|
for i in data.columns
|
||||||
]
|
]
|
||||||
|
elif isinstance(data.columns, Int64Index):
|
||||||
|
feature_names = list(map(str, data.columns))
|
||||||
else:
|
else:
|
||||||
feature_names = data.columns.format()
|
feature_names = data.columns.format()
|
||||||
|
|
||||||
|
|||||||
@ -83,6 +83,13 @@ class TestPandas(unittest.TestCase):
|
|||||||
assert dm.num_row() == 3
|
assert dm.num_row() == 3
|
||||||
assert dm.num_col() == 2
|
assert dm.num_col() == 2
|
||||||
|
|
||||||
|
df_int = pd.DataFrame([[1, 1.1], [2, 2.2]], columns=[9, 10])
|
||||||
|
dm_int = xgb.DMatrix(df_int)
|
||||||
|
df_range = pd.DataFrame([[1, 1.1], [2, 2.2]], columns=range(9, 11, 1))
|
||||||
|
dm_range = xgb.DMatrix(df_range)
|
||||||
|
assert dm_int.feature_names == ['9', '10'] # assert not "9 "
|
||||||
|
assert dm_int.feature_names == dm_range.feature_names
|
||||||
|
|
||||||
# test MultiIndex as columns
|
# test MultiIndex as columns
|
||||||
df = pd.DataFrame(
|
df = pd.DataFrame(
|
||||||
[
|
[
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user