Accept numpy array for DMatrix slice index. (#6368)

This commit is contained in:
Jiaming Yuan 2020-12-16 14:42:52 +08:00 committed by GitHub
parent ef4a0e0aac
commit 347f593169
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 28 additions and 14 deletions

View File

@ -755,35 +755,40 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
number of columns : int
"""
ret = c_bst_ulong()
_check_call(_LIB.XGDMatrixNumCol(self.handle,
ctypes.byref(ret)))
_check_call(_LIB.XGDMatrixNumCol(self.handle, ctypes.byref(ret)))
return ret.value
def slice(self, rindex, allow_groups=False):
def slice(
self, rindex: Union[List[int], np.ndarray], allow_groups: bool = False
) -> "DMatrix":
"""Slice the DMatrix and return a new DMatrix that only contains `rindex`.
Parameters
----------
rindex : list
rindex
List of indices to be selected.
allow_groups : boolean
allow_groups
Allow slicing of a matrix with a groups attribute
Returns
-------
res : DMatrix
res
A new DMatrix containing only selected indices.
"""
from .data import _maybe_np_slice
res = DMatrix(None)
res.handle = ctypes.c_void_p()
_check_call(_LIB.XGDMatrixSliceDMatrixEx(
self.handle,
c_array(ctypes.c_int, rindex),
c_bst_ulong(len(rindex)),
ctypes.byref(res.handle),
ctypes.c_int(1 if allow_groups else 0)))
res.feature_names = self.feature_names
res.feature_types = self.feature_types
rindex = _maybe_np_slice(rindex, dtype=np.int32)
_check_call(
_LIB.XGDMatrixSliceDMatrixEx(
self.handle,
c_array(ctypes.c_int, rindex),
c_bst_ulong(len(rindex)),
ctypes.byref(res.handle),
ctypes.c_int(1 if allow_groups else 0),
)
)
return res
@property

View File

@ -299,6 +299,11 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
out.feature_weigths.Resize(this->feature_weigths.Size());
out.feature_weigths.Copy(this->feature_weigths);
out.feature_names = this->feature_names;
out.feature_types.Copy(this->feature_types);
out.feature_type_names = this->feature_type_names;
return out;
}

View File

@ -145,6 +145,10 @@ class TestDMatrix:
num_boost_round=2, evals=[(d2, 'd2'), (sliced, 'sliced')], evals_result=eval_res)
np.testing.assert_equal(eval_res['d2']['mlogloss'], eval_res['sliced']['mlogloss'])
ridxs_arr = np.array(ridxs)[1:] # handles numpy slice correctly
sliced = d.slice(ridxs_arr)
np.testing.assert_equal(sliced.get_label(), y[2:7])
def test_feature_names_slice(self):
data = np.random.randn(5, 5)