Support latest pandas Index type. (#7595)
This commit is contained in:
parent
511805c981
commit
24789429fd
@ -33,14 +33,13 @@ def lazy_isinstance(instance, module, name):
|
||||
# pandas
|
||||
try:
|
||||
from pandas import DataFrame, Series
|
||||
from pandas import MultiIndex, Int64Index
|
||||
from pandas import MultiIndex
|
||||
from pandas import concat as pandas_concat
|
||||
|
||||
PANDAS_INSTALLED = True
|
||||
except ImportError:
|
||||
|
||||
MultiIndex = object
|
||||
Int64Index = object
|
||||
DataFrame: Any = object
|
||||
Series = object
|
||||
pandas_concat = None
|
||||
|
||||
@ -264,7 +264,7 @@ def _transform_pandas_df(
|
||||
if feature_names is None and meta is None:
|
||||
if isinstance(data.columns, pd.MultiIndex):
|
||||
feature_names = [" ".join([str(x) for x in i]) for i in data.columns]
|
||||
elif isinstance(data.columns, (pd.Int64Index, pd.RangeIndex)):
|
||||
elif isinstance(data.columns, (pd.Index, pd.RangeIndex)):
|
||||
feature_names = list(map(str, data.columns))
|
||||
else:
|
||||
feature_names = data.columns.format()
|
||||
|
||||
@ -109,6 +109,12 @@ class TestPandas:
|
||||
assert dm.num_row() == 2
|
||||
assert dm.num_col() == 6
|
||||
|
||||
# test Index as columns
|
||||
df = pd.DataFrame([[1, 1.1], [2, 2.2]], columns=pd.Index([1, 2]))
|
||||
print(df.columns, isinstance(df.columns, pd.Index))
|
||||
Xy = xgb.DMatrix(df)
|
||||
np.testing.assert_equal(np.array(Xy.feature_names), np.array(["1", "2"]))
|
||||
|
||||
def test_slice(self):
|
||||
rng = np.random.RandomState(1994)
|
||||
rows = 100
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user