Implement get_group. (#7564)

This commit is contained in:
Jiaming Yuan 2022-01-16 02:07:42 +08:00 committed by GitHub
parent 52277cc3da
commit 13b0fa4b97
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 15 additions and 1 deletions

View File

@ -884,10 +884,20 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
Returns
-------
base_margin : float
base_margin
"""
return self.get_float_info('base_margin')
def get_group(self) -> np.ndarray:
"""Get the group of the DMatrix.
Returns
-------
group
"""
group_ptr = self.get_uint_info("group_ptr")
return np.diff(group_ptr)
def num_row(self) -> int:
"""Get the number of rows in the DMatrix.

View File

@ -285,6 +285,10 @@ class TestDMatrix:
dtrain.get_float_info('base_margin')
dtrain.get_uint_info('group_ptr')
group_len = np.array([2, 3, 4])
dtrain.set_group(group_len)
np.testing.assert_equal(group_len, dtrain.get_group())
def test_qid(self):
rows = 100
cols = 10