diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 775050f63..c82cad227 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -755,35 +755,40 @@ class DMatrix: # pylint: disable=too-many-instance-attributes number of columns : int """ ret = c_bst_ulong() - _check_call(_LIB.XGDMatrixNumCol(self.handle, - ctypes.byref(ret))) + _check_call(_LIB.XGDMatrixNumCol(self.handle, ctypes.byref(ret))) return ret.value - def slice(self, rindex, allow_groups=False): + def slice( + self, rindex: Union[List[int], np.ndarray], allow_groups: bool = False + ) -> "DMatrix": """Slice the DMatrix and return a new DMatrix that only contains `rindex`. Parameters ---------- - rindex : list + rindex List of indices to be selected. - allow_groups : boolean + allow_groups Allow slicing of a matrix with a groups attribute Returns ------- - res : DMatrix + res A new DMatrix containing only selected indices. """ + from .data import _maybe_np_slice + res = DMatrix(None) res.handle = ctypes.c_void_p() - _check_call(_LIB.XGDMatrixSliceDMatrixEx( - self.handle, - c_array(ctypes.c_int, rindex), - c_bst_ulong(len(rindex)), - ctypes.byref(res.handle), - ctypes.c_int(1 if allow_groups else 0))) - res.feature_names = self.feature_names - res.feature_types = self.feature_types + rindex = _maybe_np_slice(rindex, dtype=np.int32) + _check_call( + _LIB.XGDMatrixSliceDMatrixEx( + self.handle, + c_array(ctypes.c_int, rindex), + c_bst_ulong(len(rindex)), + ctypes.byref(res.handle), + ctypes.c_int(1 if allow_groups else 0), + ) + ) return res @property diff --git a/src/data/data.cc b/src/data/data.cc index f203fd3dc..186ec88d5 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -299,6 +299,11 @@ MetaInfo MetaInfo::Slice(common::Span ridxs) const { out.feature_weigths.Resize(this->feature_weigths.Size()); out.feature_weigths.Copy(this->feature_weigths); + + out.feature_names = this->feature_names; + out.feature_types.Copy(this->feature_types); + out.feature_type_names = this->feature_type_names; + return out; } diff --git a/tests/python/test_dmatrix.py b/tests/python/test_dmatrix.py index 85af1451f..7828ac38a 100644 --- a/tests/python/test_dmatrix.py +++ b/tests/python/test_dmatrix.py @@ -145,6 +145,10 @@ class TestDMatrix: num_boost_round=2, evals=[(d2, 'd2'), (sliced, 'sliced')], evals_result=eval_res) np.testing.assert_equal(eval_res['d2']['mlogloss'], eval_res['sliced']['mlogloss']) + ridxs_arr = np.array(ridxs)[1:] # handles numpy slice correctly + sliced = d.slice(ridxs_arr) + np.testing.assert_equal(sliced.get_label(), y[2:7]) + def test_feature_names_slice(self): data = np.random.randn(5, 5)