* Fix #3714: preserve feature names when slicing DMatrix * Add test
This commit is contained in:
parent
813d2436d3
commit
10cd7c8447
@ -351,6 +351,11 @@ class DMatrix(object):
|
|||||||
# force into void_p, mac need to pass things in as void_p
|
# force into void_p, mac need to pass things in as void_p
|
||||||
if data is None:
|
if data is None:
|
||||||
self.handle = None
|
self.handle = None
|
||||||
|
|
||||||
|
if feature_names is not None:
|
||||||
|
self._feature_names = feature_names
|
||||||
|
if feature_types is not None:
|
||||||
|
self._feature_types = feature_types
|
||||||
return
|
return
|
||||||
|
|
||||||
data, feature_names, feature_types = _maybe_pandas_data(data,
|
data, feature_names, feature_types = _maybe_pandas_data(data,
|
||||||
@ -739,7 +744,8 @@ class DMatrix(object):
|
|||||||
res : DMatrix
|
res : DMatrix
|
||||||
A new DMatrix containing only selected indices.
|
A new DMatrix containing only selected indices.
|
||||||
"""
|
"""
|
||||||
res = DMatrix(None, feature_names=self.feature_names)
|
res = DMatrix(None, feature_names=self.feature_names,
|
||||||
|
feature_types=self.feature_types)
|
||||||
res.handle = ctypes.c_void_p()
|
res.handle = ctypes.c_void_p()
|
||||||
_check_call(_LIB.XGDMatrixSliceDMatrix(self.handle,
|
_check_call(_LIB.XGDMatrixSliceDMatrix(self.handle,
|
||||||
c_array(ctypes.c_int, rindex),
|
c_array(ctypes.c_int, rindex),
|
||||||
|
|||||||
@ -118,6 +118,8 @@ class TestBasic(unittest.TestCase):
|
|||||||
dm.feature_names = list('abcde')
|
dm.feature_names = list('abcde')
|
||||||
assert dm.feature_names == list('abcde')
|
assert dm.feature_names == list('abcde')
|
||||||
|
|
||||||
|
assert dm.slice([0, 1]).feature_names == dm.feature_names
|
||||||
|
|
||||||
dm.feature_types = 'q'
|
dm.feature_types = 'q'
|
||||||
assert dm.feature_types == list('qqqqq')
|
assert dm.feature_types == list('qqqqq')
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user