Merge pull request #522 from sinhrks/pandas

python DMatrix now accepts pandas DataFrame
This commit is contained in:
Tianqi Chen
2015-10-02 10:19:14 -07:00
2 changed files with 51 additions and 1 deletions

View File

@@ -97,6 +97,27 @@ class TestBasic(unittest.TestCase):
dm = xgb.DMatrix(dummy, feature_names=list('abcde'))
self.assertRaises(ValueError, bst.predict, dm)
def test_pandas(self):
import pandas as pd
df = pd.DataFrame([[1, 2., True], [2, 3., False]], columns=['a', 'b', 'c'])
dm = xgb.DMatrix(df, label=pd.Series([1, 2]))
assert dm.feature_names == ['a', 'b', 'c']
assert dm.feature_types == ['int', 'q', 'i']
assert dm.num_row() == 2
assert dm.num_col() == 3
# overwrite feature_names and feature_types
dm = xgb.DMatrix(df, label=pd.Series([1, 2]),
feature_names=['x', 'y', 'z'], feature_types=['q', 'q', 'q'])
assert dm.feature_names == ['x', 'y', 'z']
assert dm.feature_types == ['q', 'q', 'q']
assert dm.num_row() == 2
assert dm.num_col() == 3
# incorrect dtypes
df = pd.DataFrame([[1, 2., 'x'], [2, 3., 'y']], columns=['a', 'b', 'c'])
self.assertRaises(ValueError, xgb.DMatrix, df)
def test_load_file_invalid(self):
self.assertRaises(ValueError, xgb.Booster,