Accept numpy array for DMatrix slice index. (#6368)

This commit is contained in:
Jiaming Yuan 2020-12-16 14:42:52 +08:00 committed by GitHub
parent ef4a0e0aac
commit 347f593169
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 28 additions and 14 deletions

View File

@ -755,35 +755,40 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
number of columns : int number of columns : int
""" """
ret = c_bst_ulong() ret = c_bst_ulong()
_check_call(_LIB.XGDMatrixNumCol(self.handle, _check_call(_LIB.XGDMatrixNumCol(self.handle, ctypes.byref(ret)))
ctypes.byref(ret)))
return ret.value return ret.value
def slice(self, rindex, allow_groups=False): def slice(
self, rindex: Union[List[int], np.ndarray], allow_groups: bool = False
) -> "DMatrix":
"""Slice the DMatrix and return a new DMatrix that only contains `rindex`. """Slice the DMatrix and return a new DMatrix that only contains `rindex`.
Parameters Parameters
---------- ----------
rindex : list rindex
List of indices to be selected. List of indices to be selected.
allow_groups : boolean allow_groups
Allow slicing of a matrix with a groups attribute Allow slicing of a matrix with a groups attribute
Returns Returns
------- -------
res : DMatrix res
A new DMatrix containing only selected indices. A new DMatrix containing only selected indices.
""" """
from .data import _maybe_np_slice
res = DMatrix(None) res = DMatrix(None)
res.handle = ctypes.c_void_p() res.handle = ctypes.c_void_p()
_check_call(_LIB.XGDMatrixSliceDMatrixEx( rindex = _maybe_np_slice(rindex, dtype=np.int32)
self.handle, _check_call(
c_array(ctypes.c_int, rindex), _LIB.XGDMatrixSliceDMatrixEx(
c_bst_ulong(len(rindex)), self.handle,
ctypes.byref(res.handle), c_array(ctypes.c_int, rindex),
ctypes.c_int(1 if allow_groups else 0))) c_bst_ulong(len(rindex)),
res.feature_names = self.feature_names ctypes.byref(res.handle),
res.feature_types = self.feature_types ctypes.c_int(1 if allow_groups else 0),
)
)
return res return res
@property @property

View File

@ -299,6 +299,11 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
out.feature_weigths.Resize(this->feature_weigths.Size()); out.feature_weigths.Resize(this->feature_weigths.Size());
out.feature_weigths.Copy(this->feature_weigths); out.feature_weigths.Copy(this->feature_weigths);
out.feature_names = this->feature_names;
out.feature_types.Copy(this->feature_types);
out.feature_type_names = this->feature_type_names;
return out; return out;
} }

View File

@ -145,6 +145,10 @@ class TestDMatrix:
num_boost_round=2, evals=[(d2, 'd2'), (sliced, 'sliced')], evals_result=eval_res) num_boost_round=2, evals=[(d2, 'd2'), (sliced, 'sliced')], evals_result=eval_res)
np.testing.assert_equal(eval_res['d2']['mlogloss'], eval_res['sliced']['mlogloss']) np.testing.assert_equal(eval_res['d2']['mlogloss'], eval_res['sliced']['mlogloss'])
ridxs_arr = np.array(ridxs)[1:] # handles numpy slice correctly
sliced = d.slice(ridxs_arr)
np.testing.assert_equal(sliced.get_label(), y[2:7])
def test_feature_names_slice(self): def test_feature_names_slice(self):
data = np.random.randn(5, 5) data = np.random.randn(5, 5)