Implement get_group. (#7564)
This commit is contained in:
parent
52277cc3da
commit
13b0fa4b97
@ -884,10 +884,20 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
|||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
base_margin : float
|
base_margin
|
||||||
"""
|
"""
|
||||||
return self.get_float_info('base_margin')
|
return self.get_float_info('base_margin')
|
||||||
|
|
||||||
|
def get_group(self) -> np.ndarray:
|
||||||
|
"""Get the group of the DMatrix.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
group
|
||||||
|
"""
|
||||||
|
group_ptr = self.get_uint_info("group_ptr")
|
||||||
|
return np.diff(group_ptr)
|
||||||
|
|
||||||
def num_row(self) -> int:
|
def num_row(self) -> int:
|
||||||
"""Get the number of rows in the DMatrix.
|
"""Get the number of rows in the DMatrix.
|
||||||
|
|
||||||
|
|||||||
@ -285,6 +285,10 @@ class TestDMatrix:
|
|||||||
dtrain.get_float_info('base_margin')
|
dtrain.get_float_info('base_margin')
|
||||||
dtrain.get_uint_info('group_ptr')
|
dtrain.get_uint_info('group_ptr')
|
||||||
|
|
||||||
|
group_len = np.array([2, 3, 4])
|
||||||
|
dtrain.set_group(group_len)
|
||||||
|
np.testing.assert_equal(group_len, dtrain.get_group())
|
||||||
|
|
||||||
def test_qid(self):
|
def test_qid(self):
|
||||||
rows = 100
|
rows = 100
|
||||||
cols = 10
|
cols = 10
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user