Implement get_group. (#7564)
This commit is contained in:
parent
52277cc3da
commit
13b0fa4b97
@ -884,10 +884,20 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
|
||||
Returns
|
||||
-------
|
||||
base_margin : float
|
||||
base_margin
|
||||
"""
|
||||
return self.get_float_info('base_margin')
|
||||
|
||||
def get_group(self) -> np.ndarray:
|
||||
"""Get the group of the DMatrix.
|
||||
|
||||
Returns
|
||||
-------
|
||||
group
|
||||
"""
|
||||
group_ptr = self.get_uint_info("group_ptr")
|
||||
return np.diff(group_ptr)
|
||||
|
||||
def num_row(self) -> int:
|
||||
"""Get the number of rows in the DMatrix.
|
||||
|
||||
|
||||
@ -285,6 +285,10 @@ class TestDMatrix:
|
||||
dtrain.get_float_info('base_margin')
|
||||
dtrain.get_uint_info('group_ptr')
|
||||
|
||||
group_len = np.array([2, 3, 4])
|
||||
dtrain.set_group(group_len)
|
||||
np.testing.assert_equal(group_len, dtrain.get_group())
|
||||
|
||||
def test_qid(self):
|
||||
rows = 100
|
||||
cols = 10
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user