Fix #3714: preserve feature names when slicing DMatrix (#3766)

* Fix #3714: preserve feature names when slicing DMatrix

* Add test
This commit is contained in:
Philip Hyunsu Cho 2018-10-08 01:04:33 -07:00 committed by GitHub
parent 813d2436d3
commit 10cd7c8447
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 1 deletion

View File

@@ -351,6 +351,11 @@ class DMatrix(object):
# force into void_p, mac need to pass things in as void_p
if data is None:
self.handle = None
if feature_names is not None:
self._feature_names = feature_names
if feature_types is not None:
self._feature_types = feature_types
return
data, feature_names, feature_types = _maybe_pandas_data(data,
@@ -739,7 +744,8 @@ class DMatrix(object):
res : DMatrix
A new DMatrix containing only selected indices.
"""
- res = DMatrix(None, feature_names=self.feature_names)
+ res = DMatrix(None, feature_names=self.feature_names,
+ feature_types=self.feature_types)
res.handle = ctypes.c_void_p()
_check_call(_LIB.XGDMatrixSliceDMatrix(self.handle,
c_array(ctypes.c_int, rindex),

View File

@@ -118,6 +118,8 @@ class TestBasic(unittest.TestCase):
dm.feature_names = list('abcde')
assert dm.feature_names == list('abcde')
assert dm.slice([0, 1]).feature_names == dm.feature_names
dm.feature_types = 'q'
assert dm.feature_types == list('qqqqq')