Implement get_group. (#7564)

This commit is contained in:
Jiaming Yuan
2022-01-16 02:07:42 +08:00
committed by GitHub
parent 52277cc3da
commit 13b0fa4b97
2 changed files with 15 additions and 1 deletions

View File

@@ -884,10 +884,20 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
Returns
-------
base_margin : float
base_margin
"""
return self.get_float_info('base_margin')
def get_group(self) -> np.ndarray:
"""Get the group of the DMatrix.
Returns
-------
group
"""
group_ptr = self.get_uint_info("group_ptr")
return np.diff(group_ptr)
def num_row(self) -> int:
"""Get the number of rows in the DMatrix.