Merge pull request #522 from sinhrks/pandas

python DMatrix now accepts pandas DataFrame
This commit is contained in:
Tianqi Chen 2015-10-02 10:19:14 -07:00
commit 2859c190cd
2 changed files with 51 additions and 1 deletions

View File

@ -138,6 +138,28 @@ def c_array(ctype, values):
return (ctype * len(values))(*values) return (ctype * len(values))(*values)
def _maybe_from_pandas(data, feature_names, feature_types):
""" Extract internal data from pd.DataFrame """
try:
import pandas as pd
except ImportError:
return data, feature_names, feature_types
if not isinstance(data, pd.DataFrame):
return data, feature_names, feature_types
dtypes = data.dtypes
if not all(dtype.name in ('int64', 'float64', 'bool') for dtype in dtypes):
raise ValueError('DataFrame.dtypes must be int, float or bool')
if feature_names is None:
feature_names = data.columns.tolist()
if feature_types is None:
mapper = {'int64': 'int', 'float64': 'q', 'bool': 'i'}
feature_types = [mapper[dtype.name] for dtype in dtypes]
data = data.values.astype('float')
return data, feature_names, feature_types
class DMatrix(object): class DMatrix(object):
"""Data Matrix used in XGBoost. """Data Matrix used in XGBoost.
@ -157,7 +179,7 @@ class DMatrix(object):
Parameters Parameters
---------- ----------
data : string/numpy array/scipy.sparse data : string/numpy array/scipy.sparse/pd.DataFrame
Data source of DMatrix. Data source of DMatrix.
When data is string type, it represents the path libsvm format txt file, When data is string type, it represents the path libsvm format txt file,
or binary file that xgboost can read from. or binary file that xgboost can read from.
@ -178,6 +200,13 @@ class DMatrix(object):
if data is None: if data is None:
self.handle = None self.handle = None
return return
klass = getattr(getattr(data, '__class__', None), '__name__', None)
if klass == 'DataFrame':
# once check class name to avoid unnecessary pandas import
data, feature_names, feature_types = _maybe_from_pandas(data, feature_names,
feature_types)
if isinstance(data, STRING_TYPES): if isinstance(data, STRING_TYPES):
self.handle = ctypes.c_void_p() self.handle = ctypes.c_void_p()
_check_call(_LIB.XGDMatrixCreateFromFile(c_str(data), _check_call(_LIB.XGDMatrixCreateFromFile(c_str(data),

View File

@ -97,6 +97,27 @@ class TestBasic(unittest.TestCase):
dm = xgb.DMatrix(dummy, feature_names=list('abcde')) dm = xgb.DMatrix(dummy, feature_names=list('abcde'))
self.assertRaises(ValueError, bst.predict, dm) self.assertRaises(ValueError, bst.predict, dm)
def test_pandas(self):
import pandas as pd
df = pd.DataFrame([[1, 2., True], [2, 3., False]], columns=['a', 'b', 'c'])
dm = xgb.DMatrix(df, label=pd.Series([1, 2]))
assert dm.feature_names == ['a', 'b', 'c']
assert dm.feature_types == ['int', 'q', 'i']
assert dm.num_row() == 2
assert dm.num_col() == 3
# overwrite feature_names and feature_types
dm = xgb.DMatrix(df, label=pd.Series([1, 2]),
feature_names=['x', 'y', 'z'], feature_types=['q', 'q', 'q'])
assert dm.feature_names == ['x', 'y', 'z']
assert dm.feature_types == ['q', 'q', 'q']
assert dm.num_row() == 2
assert dm.num_col() == 3
# incorrect dtypes
df = pd.DataFrame([[1, 2., 'x'], [2, 3., 'y']], columns=['a', 'b', 'c'])
self.assertRaises(ValueError, xgb.DMatrix, df)
def test_load_file_invalid(self): def test_load_file_invalid(self):
self.assertRaises(ValueError, xgb.Booster, self.assertRaises(ValueError, xgb.Booster,