Accept numpy array for DMatrix slice index. (#6368)

This commit is contained in:
Jiaming Yuan
2020-12-16 14:42:52 +08:00
committed by GitHub
parent ef4a0e0aac
commit 347f593169
3 changed files with 28 additions and 14 deletions

View File

@@ -755,35 +755,40 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
number of columns : int
"""
ret = c_bst_ulong()
_check_call(_LIB.XGDMatrixNumCol(self.handle,
ctypes.byref(ret)))
_check_call(_LIB.XGDMatrixNumCol(self.handle, ctypes.byref(ret)))
return ret.value
def slice(self, rindex, allow_groups=False):
def slice(
self, rindex: Union[List[int], np.ndarray], allow_groups: bool = False
) -> "DMatrix":
"""Slice the DMatrix and return a new DMatrix that only contains `rindex`.
Parameters
----------
rindex : list
rindex
List of indices to be selected.
allow_groups : boolean
allow_groups
Allow slicing of a matrix with a groups attribute
Returns
-------
res : DMatrix
res
A new DMatrix containing only selected indices.
"""
from .data import _maybe_np_slice
res = DMatrix(None)
res.handle = ctypes.c_void_p()
_check_call(_LIB.XGDMatrixSliceDMatrixEx(
self.handle,
c_array(ctypes.c_int, rindex),
c_bst_ulong(len(rindex)),
ctypes.byref(res.handle),
ctypes.c_int(1 if allow_groups else 0)))
res.feature_names = self.feature_names
res.feature_types = self.feature_types
rindex = _maybe_np_slice(rindex, dtype=np.int32)
_check_call(
_LIB.XGDMatrixSliceDMatrixEx(
self.handle,
c_array(ctypes.c_int, rindex),
c_bst_ulong(len(rindex)),
ctypes.byref(res.handle),
ctypes.c_int(1 if allow_groups else 0),
)
)
return res
@property