Accept numpy array for DMatrix slice index. (#6368)
This commit is contained in:
parent
ef4a0e0aac
commit
347f593169
@ -755,35 +755,40 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
|||||||
number of columns : int
|
number of columns : int
|
||||||
"""
|
"""
|
||||||
ret = c_bst_ulong()
|
ret = c_bst_ulong()
|
||||||
_check_call(_LIB.XGDMatrixNumCol(self.handle,
|
_check_call(_LIB.XGDMatrixNumCol(self.handle, ctypes.byref(ret)))
|
||||||
ctypes.byref(ret)))
|
|
||||||
return ret.value
|
return ret.value
|
||||||
|
|
||||||
def slice(self, rindex, allow_groups=False):
|
def slice(
|
||||||
|
self, rindex: Union[List[int], np.ndarray], allow_groups: bool = False
|
||||||
|
) -> "DMatrix":
|
||||||
"""Slice the DMatrix and return a new DMatrix that only contains `rindex`.
|
"""Slice the DMatrix and return a new DMatrix that only contains `rindex`.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
rindex : list
|
rindex
|
||||||
List of indices to be selected.
|
List of indices to be selected.
|
||||||
allow_groups : boolean
|
allow_groups
|
||||||
Allow slicing of a matrix with a groups attribute
|
Allow slicing of a matrix with a groups attribute
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
res : DMatrix
|
res
|
||||||
A new DMatrix containing only selected indices.
|
A new DMatrix containing only selected indices.
|
||||||
"""
|
"""
|
||||||
|
from .data import _maybe_np_slice
|
||||||
|
|
||||||
res = DMatrix(None)
|
res = DMatrix(None)
|
||||||
res.handle = ctypes.c_void_p()
|
res.handle = ctypes.c_void_p()
|
||||||
_check_call(_LIB.XGDMatrixSliceDMatrixEx(
|
rindex = _maybe_np_slice(rindex, dtype=np.int32)
|
||||||
|
_check_call(
|
||||||
|
_LIB.XGDMatrixSliceDMatrixEx(
|
||||||
self.handle,
|
self.handle,
|
||||||
c_array(ctypes.c_int, rindex),
|
c_array(ctypes.c_int, rindex),
|
||||||
c_bst_ulong(len(rindex)),
|
c_bst_ulong(len(rindex)),
|
||||||
ctypes.byref(res.handle),
|
ctypes.byref(res.handle),
|
||||||
ctypes.c_int(1 if allow_groups else 0)))
|
ctypes.c_int(1 if allow_groups else 0),
|
||||||
res.feature_names = self.feature_names
|
)
|
||||||
res.feature_types = self.feature_types
|
)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@ -299,6 +299,11 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
|
|||||||
|
|
||||||
out.feature_weigths.Resize(this->feature_weigths.Size());
|
out.feature_weigths.Resize(this->feature_weigths.Size());
|
||||||
out.feature_weigths.Copy(this->feature_weigths);
|
out.feature_weigths.Copy(this->feature_weigths);
|
||||||
|
|
||||||
|
out.feature_names = this->feature_names;
|
||||||
|
out.feature_types.Copy(this->feature_types);
|
||||||
|
out.feature_type_names = this->feature_type_names;
|
||||||
|
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -145,6 +145,10 @@ class TestDMatrix:
|
|||||||
num_boost_round=2, evals=[(d2, 'd2'), (sliced, 'sliced')], evals_result=eval_res)
|
num_boost_round=2, evals=[(d2, 'd2'), (sliced, 'sliced')], evals_result=eval_res)
|
||||||
np.testing.assert_equal(eval_res['d2']['mlogloss'], eval_res['sliced']['mlogloss'])
|
np.testing.assert_equal(eval_res['d2']['mlogloss'], eval_res['sliced']['mlogloss'])
|
||||||
|
|
||||||
|
ridxs_arr = np.array(ridxs)[1:] # handles numpy slice correctly
|
||||||
|
sliced = d.slice(ridxs_arr)
|
||||||
|
np.testing.assert_equal(sliced.get_label(), y[2:7])
|
||||||
|
|
||||||
def test_feature_names_slice(self):
|
def test_feature_names_slice(self):
|
||||||
data = np.random.randn(5, 5)
|
data = np.random.randn(5, 5)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user