Merge pull request #522 from sinhrks/pandas

python DMatrix now accepts pandas DataFrame
This commit is contained in:
Tianqi Chen
2015-10-02 10:19:14 -07:00
2 changed files with 51 additions and 1 deletions

View File

@@ -138,6 +138,28 @@ def c_array(ctype, values):
return (ctype * len(values))(*values)
def _maybe_from_pandas(data, feature_names, feature_types):
""" Extract internal data from pd.DataFrame """
try:
import pandas as pd
except ImportError:
return data, feature_names, feature_types
if not isinstance(data, pd.DataFrame):
return data, feature_names, feature_types
dtypes = data.dtypes
if not all(dtype.name in ('int64', 'float64', 'bool') for dtype in dtypes):
raise ValueError('DataFrame.dtypes must be int, float or bool')
if feature_names is None:
feature_names = data.columns.tolist()
if feature_types is None:
mapper = {'int64': 'int', 'float64': 'q', 'bool': 'i'}
feature_types = [mapper[dtype.name] for dtype in dtypes]
data = data.values.astype('float')
return data, feature_names, feature_types
class DMatrix(object):
"""Data Matrix used in XGBoost.
@@ -157,7 +179,7 @@ class DMatrix(object):
Parameters
----------
data : string/numpy array/scipy.sparse
data : string/numpy array/scipy.sparse/pd.DataFrame
Data source of DMatrix.
When data is string type, it represents the path libsvm format txt file,
or binary file that xgboost can read from.
@@ -178,6 +200,13 @@ class DMatrix(object):
if data is None:
self.handle = None
return
klass = getattr(getattr(data, '__class__', None), '__name__', None)
if klass == 'DataFrame':
# once check class name to avoid unnecessary pandas import
data, feature_names, feature_types = _maybe_from_pandas(data, feature_names,
feature_types)
if isinstance(data, STRING_TYPES):
self.handle = ctypes.c_void_p()
_check_call(_LIB.XGDMatrixCreateFromFile(c_str(data),