python DMatrix now accepts pandas DataFrame

This commit is contained in:
sinhrks 2015-10-01 22:39:56 +09:00
parent db490d1c75
commit b943becc61
3 changed files with 52 additions and 2 deletions

View File

@ -138,6 +138,28 @@ def c_array(ctype, values):
return (ctype * len(values))(*values) return (ctype * len(values))(*values)
def _maybe_from_pandas(data, feature_names, feature_types):
""" Extract internal data from pd.DataFrame """
try:
import pandas as pd
except ImportError:
return data, feature_names, feature_types
if not isinstance(data, pd.DataFrame):
return data, feature_names, feature_types
dtypes = data.dtypes
if not all(dtype.name in ('int64', 'float64', 'bool') for dtype in dtypes):
raise ValueError('DataFrame.dtypes must be int, float or bool')
if feature_names is None:
feature_names = data.columns.tolist()
if feature_types is None:
mapper = {'int64': 'int', 'float64': 'q', 'bool': 'i'}
feature_types = [mapper[dtype.name] for dtype in dtypes]
data = data.values.astype('float')
return data, feature_names, feature_types
class DMatrix(object): class DMatrix(object):
"""Data Matrix used in XGBoost. """Data Matrix used in XGBoost.
@ -157,7 +179,7 @@ class DMatrix(object):
Parameters Parameters
---------- ----------
data : string/numpy array/scipy.sparse data : string/numpy array/scipy.sparse/pd.DataFrame
Data source of DMatrix. Data source of DMatrix.
When data is string type, it represents the path libsvm format txt file, When data is string type, it represents the path libsvm format txt file,
or binary file that xgboost can read from. or binary file that xgboost can read from.
@ -178,6 +200,13 @@ class DMatrix(object):
if data is None: if data is None:
self.handle = None self.handle = None
return return
klass = getattr(getattr(data, '__class__', None), '__name__', None)
if klass == 'DataFrame':
# check the class name first to avoid an unnecessary pandas import
data, feature_names, feature_types = _maybe_from_pandas(data, feature_names,
feature_types)
if isinstance(data, STRING_TYPES): if isinstance(data, STRING_TYPES):
self.handle = ctypes.c_void_p() self.handle = ctypes.c_void_p()
_check_call(_LIB.XGDMatrixCreateFromFile(c_str(data), _check_call(_LIB.XGDMatrixCreateFromFile(c_str(data),

View File

@ -64,7 +64,7 @@ if [ ${TASK} == "python-package" -o ${TASK} == "python-package3" ]; then
conda create -n myenv python=2.7 conda create -n myenv python=2.7
fi fi
source activate myenv source activate myenv
conda install numpy scipy matplotlib nose conda install numpy scipy pandas matplotlib nose
python -m pip install graphviz python -m pip install graphviz
make all CXX=${CXX} || exit -1 make all CXX=${CXX} || exit -1

View File

@ -97,6 +97,27 @@ class TestBasic(unittest.TestCase):
dm = xgb.DMatrix(dummy, feature_names=list('abcde')) dm = xgb.DMatrix(dummy, feature_names=list('abcde'))
self.assertRaises(ValueError, bst.predict, dm) self.assertRaises(ValueError, bst.predict, dm)
def test_pandas(self):
    """Check that DMatrix accepts a pandas DataFrame as its data source.

    Covers three cases: names/types inferred from the frame, names/types
    explicitly overridden by the caller, and rejection of a frame that
    contains a non-numeric column.
    """
    import pandas as pd

    df = pd.DataFrame([[1, 2., True], [2, 3., False]], columns=['a', 'b', 'c'])

    # feature names and types are inferred from the DataFrame
    dm = xgb.DMatrix(df, label=pd.Series([1, 2]))
    self.assertEqual(dm.feature_names, ['a', 'b', 'c'])
    self.assertEqual(dm.feature_types, ['int', 'q', 'i'])
    self.assertEqual(dm.num_row(), 2)
    self.assertEqual(dm.num_col(), 3)

    # overwrite feature_names and feature_types
    dm = xgb.DMatrix(df, label=pd.Series([1, 2]),
                     feature_names=['x', 'y', 'z'], feature_types=['q', 'q', 'q'])
    self.assertEqual(dm.feature_names, ['x', 'y', 'z'])
    self.assertEqual(dm.feature_types, ['q', 'q', 'q'])
    self.assertEqual(dm.num_row(), 2)
    self.assertEqual(dm.num_col(), 3)

    # incorrect dtypes: a string column must be rejected
    df = pd.DataFrame([[1, 2., 'x'], [2, 3., 'y']], columns=['a', 'b', 'c'])
    self.assertRaises(ValueError, xgb.DMatrix, df)
def test_load_file_invalid(self): def test_load_file_invalid(self):
self.assertRaises(ValueError, xgb.Booster, self.assertRaises(ValueError, xgb.Booster,