* Fix #3714: preserve feature names when slicing DMatrix * Add test
This commit is contained in:
parent
813d2436d3
commit
10cd7c8447
@ -351,6 +351,11 @@ class DMatrix(object):
|
||||
# force into void_p, mac need to pass things in as void_p
|
||||
if data is None:
|
||||
self.handle = None
|
||||
|
||||
if feature_names is not None:
|
||||
self._feature_names = feature_names
|
||||
if feature_types is not None:
|
||||
self._feature_types = feature_types
|
||||
return
|
||||
|
||||
data, feature_names, feature_types = _maybe_pandas_data(data,
|
||||
@ -739,7 +744,8 @@ class DMatrix(object):
|
||||
res : DMatrix
|
||||
A new DMatrix containing only selected indices.
|
||||
"""
|
||||
res = DMatrix(None, feature_names=self.feature_names)
|
||||
res = DMatrix(None, feature_names=self.feature_names,
|
||||
feature_types=self.feature_types)
|
||||
res.handle = ctypes.c_void_p()
|
||||
_check_call(_LIB.XGDMatrixSliceDMatrix(self.handle,
|
||||
c_array(ctypes.c_int, rindex),
|
||||
|
||||
@ -118,6 +118,8 @@ class TestBasic(unittest.TestCase):
|
||||
dm.feature_names = list('abcde')
|
||||
assert dm.feature_names == list('abcde')
|
||||
|
||||
assert dm.slice([0, 1]).feature_names == dm.feature_names
|
||||
|
||||
dm.feature_types = 'q'
|
||||
assert dm.feature_types == list('qqqqq')
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user