Accept numpy array for DMatrix slice index. (#6368)
This commit is contained in:
parent
ef4a0e0aac
commit
347f593169
@ -755,35 +755,40 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
number of columns : int
|
||||
"""
|
||||
ret = c_bst_ulong()
|
||||
_check_call(_LIB.XGDMatrixNumCol(self.handle,
|
||||
ctypes.byref(ret)))
|
||||
_check_call(_LIB.XGDMatrixNumCol(self.handle, ctypes.byref(ret)))
|
||||
return ret.value
|
||||
|
||||
def slice(self, rindex, allow_groups=False):
|
||||
def slice(
|
||||
self, rindex: Union[List[int], np.ndarray], allow_groups: bool = False
|
||||
) -> "DMatrix":
|
||||
"""Slice the DMatrix and return a new DMatrix that only contains `rindex`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
rindex : list
|
||||
rindex
|
||||
List of indices to be selected.
|
||||
allow_groups : boolean
|
||||
allow_groups
|
||||
Allow slicing of a matrix with a groups attribute
|
||||
|
||||
Returns
|
||||
-------
|
||||
res : DMatrix
|
||||
res
|
||||
A new DMatrix containing only selected indices.
|
||||
"""
|
||||
from .data import _maybe_np_slice
|
||||
|
||||
res = DMatrix(None)
|
||||
res.handle = ctypes.c_void_p()
|
||||
_check_call(_LIB.XGDMatrixSliceDMatrixEx(
|
||||
self.handle,
|
||||
c_array(ctypes.c_int, rindex),
|
||||
c_bst_ulong(len(rindex)),
|
||||
ctypes.byref(res.handle),
|
||||
ctypes.c_int(1 if allow_groups else 0)))
|
||||
res.feature_names = self.feature_names
|
||||
res.feature_types = self.feature_types
|
||||
rindex = _maybe_np_slice(rindex, dtype=np.int32)
|
||||
_check_call(
|
||||
_LIB.XGDMatrixSliceDMatrixEx(
|
||||
self.handle,
|
||||
c_array(ctypes.c_int, rindex),
|
||||
c_bst_ulong(len(rindex)),
|
||||
ctypes.byref(res.handle),
|
||||
ctypes.c_int(1 if allow_groups else 0),
|
||||
)
|
||||
)
|
||||
return res
|
||||
|
||||
@property
|
||||
|
||||
@ -299,6 +299,11 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
|
||||
|
||||
out.feature_weigths.Resize(this->feature_weigths.Size());
|
||||
out.feature_weigths.Copy(this->feature_weigths);
|
||||
|
||||
out.feature_names = this->feature_names;
|
||||
out.feature_types.Copy(this->feature_types);
|
||||
out.feature_type_names = this->feature_type_names;
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
@ -145,6 +145,10 @@ class TestDMatrix:
|
||||
num_boost_round=2, evals=[(d2, 'd2'), (sliced, 'sliced')], evals_result=eval_res)
|
||||
np.testing.assert_equal(eval_res['d2']['mlogloss'], eval_res['sliced']['mlogloss'])
|
||||
|
||||
ridxs_arr = np.array(ridxs)[1:] # handles numpy slice correctly
|
||||
sliced = d.slice(ridxs_arr)
|
||||
np.testing.assert_equal(sliced.get_label(), y[2:7])
|
||||
|
||||
def test_feature_names_slice(self):
|
||||
data = np.random.randn(5, 5)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user