diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 30ab6125d..1e955d66c 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -351,6 +351,11 @@ class DMatrix(object): # force into void_p, mac need to pass things in as void_p if data is None: self.handle = None + + if feature_names is not None: + self._feature_names = feature_names + if feature_types is not None: + self._feature_types = feature_types return data, feature_names, feature_types = _maybe_pandas_data(data, @@ -739,7 +744,8 @@ class DMatrix(object): res : DMatrix A new DMatrix containing only selected indices. """ - res = DMatrix(None, feature_names=self.feature_names) + res = DMatrix(None, feature_names=self.feature_names, + feature_types=self.feature_types) res.handle = ctypes.c_void_p() _check_call(_LIB.XGDMatrixSliceDMatrix(self.handle, c_array(ctypes.c_int, rindex), diff --git a/tests/python/test_basic.py b/tests/python/test_basic.py index 77336dcff..11e28eeb2 100644 --- a/tests/python/test_basic.py +++ b/tests/python/test_basic.py @@ -118,6 +118,8 @@ class TestBasic(unittest.TestCase): dm.feature_names = list('abcde') assert dm.feature_names == list('abcde') + assert dm.slice([0, 1]).feature_names == dm.feature_names + dm.feature_types = 'q' assert dm.feature_types == list('qqqqq')