Fix MultiIndex detection (breaks for latest pandas==0.21.0). (#2872)

This commit is contained in:
Rory Mitchell 2017-11-11 11:12:23 +13:00 committed by GitHub
parent 77ae4c8701
commit 16c63f30d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 2 deletions

View File

@ -33,9 +33,14 @@ except ImportError:
# pandas # pandas
try: try:
from pandas import DataFrame from pandas import DataFrame
from pandas import MultiIndex
PANDAS_INSTALLED = True PANDAS_INSTALLED = True
except ImportError: except ImportError:
class MultiIndex(object):
""" dummy for pandas.MultiIndex """
pass
class DataFrame(object): class DataFrame(object):
""" dummy for pandas.DataFrame """ """ dummy for pandas.DataFrame """
pass pass

View File

@ -15,7 +15,7 @@ import scipy.sparse
from .libpath import find_lib_path from .libpath import find_lib_path
from .compat import STRING_TYPES, PY3, DataFrame, py_str, PANDAS_INSTALLED from .compat import STRING_TYPES, PY3, DataFrame, MultiIndex, py_str, PANDAS_INSTALLED
# c_bst_ulong corresponds to bst_ulong defined in xgboost/c_api.h # c_bst_ulong corresponds to bst_ulong defined in xgboost/c_api.h
c_bst_ulong = ctypes.c_uint64 c_bst_ulong = ctypes.c_uint64
@ -184,7 +184,7 @@ Did not expect the data types in fields """
raise ValueError(msg + ', '.join(bad_fields)) raise ValueError(msg + ', '.join(bad_fields))
if feature_names is None: if feature_names is None:
if hasattr(data.columns, 'to_frame'): # MultiIndex if isinstance(data.columns, MultiIndex):
feature_names = [ feature_names = [
' '.join(map(str, i)) ' '.join(map(str, i))
for i in data.columns for i in data.columns