Merge pull request #522 from sinhrks/pandas
python DMatrix now accepts pandas DataFrame
This commit is contained in:
commit
2859c190cd
@ -138,6 +138,28 @@ def c_array(ctype, values):
|
||||
return (ctype * len(values))(*values)
|
||||
|
||||
|
||||
def _maybe_from_pandas(data, feature_names, feature_types):
|
||||
""" Extract internal data from pd.DataFrame """
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
return data, feature_names, feature_types
|
||||
|
||||
if not isinstance(data, pd.DataFrame):
|
||||
return data, feature_names, feature_types
|
||||
|
||||
dtypes = data.dtypes
|
||||
if not all(dtype.name in ('int64', 'float64', 'bool') for dtype in dtypes):
|
||||
raise ValueError('DataFrame.dtypes must be int, float or bool')
|
||||
|
||||
if feature_names is None:
|
||||
feature_names = data.columns.tolist()
|
||||
if feature_types is None:
|
||||
mapper = {'int64': 'int', 'float64': 'q', 'bool': 'i'}
|
||||
feature_types = [mapper[dtype.name] for dtype in dtypes]
|
||||
data = data.values.astype('float')
|
||||
return data, feature_names, feature_types
|
||||
|
||||
class DMatrix(object):
|
||||
"""Data Matrix used in XGBoost.
|
||||
|
||||
@ -157,7 +179,7 @@ class DMatrix(object):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : string/numpy array/scipy.sparse
|
||||
data : string/numpy array/scipy.sparse/pd.DataFrame
|
||||
Data source of DMatrix.
|
||||
When data is string type, it represents the path libsvm format txt file,
|
||||
or binary file that xgboost can read from.
|
||||
@ -178,6 +200,13 @@ class DMatrix(object):
|
||||
if data is None:
|
||||
self.handle = None
|
||||
return
|
||||
|
||||
klass = getattr(getattr(data, '__class__', None), '__name__', None)
|
||||
if klass == 'DataFrame':
|
||||
# check the class name first to avoid an unnecessary pandas import
|
||||
data, feature_names, feature_types = _maybe_from_pandas(data, feature_names,
|
||||
feature_types)
|
||||
|
||||
if isinstance(data, STRING_TYPES):
|
||||
self.handle = ctypes.c_void_p()
|
||||
_check_call(_LIB.XGDMatrixCreateFromFile(c_str(data),
|
||||
|
||||
@ -97,6 +97,27 @@ class TestBasic(unittest.TestCase):
|
||||
dm = xgb.DMatrix(dummy, feature_names=list('abcde'))
|
||||
self.assertRaises(ValueError, bst.predict, dm)
|
||||
|
||||
def test_pandas(self):
    """Build a DMatrix from a pandas DataFrame and check names, types and shape."""
    import pandas as pd

    rows = [[1, 2., True], [2, 3., False]]
    frame = pd.DataFrame(rows, columns=['a', 'b', 'c'])

    # no explicit names / types: both are expected to come from the frame
    inferred = xgb.DMatrix(frame, label=pd.Series([1, 2]))
    assert inferred.feature_names == ['a', 'b', 'c']
    assert inferred.feature_types == ['int', 'q', 'i']
    assert inferred.num_row() == 2
    assert inferred.num_col() == 3

    # overwrite feature_names and feature_types
    explicit = xgb.DMatrix(
        frame,
        label=pd.Series([1, 2]),
        feature_names=['x', 'y', 'z'],
        feature_types=['q', 'q', 'q'],
    )
    assert explicit.feature_names == ['x', 'y', 'z']
    assert explicit.feature_types == ['q', 'q', 'q']
    assert explicit.num_row() == 2
    assert explicit.num_col() == 3

    # incorrect dtypes: a string column must be rejected
    bad_rows = [[1, 2., 'x'], [2, 3., 'y']]
    bad_frame = pd.DataFrame(bad_rows, columns=['a', 'b', 'c'])
    self.assertRaises(ValueError, xgb.DMatrix, bad_frame)
|
||||
|
||||
def test_load_file_invalid(self):
|
||||
|
||||
self.assertRaises(ValueError, xgb.Booster,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user