Implement get_group. (#7564)
This commit is contained in:
@@ -884,10 +884,20 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
|
||||
Returns
|
||||
-------
|
||||
base_margin : float
|
||||
base_margin
|
||||
"""
|
||||
return self.get_float_info('base_margin')
|
||||
|
||||
def get_group(self) -> np.ndarray:
|
||||
"""Get the group of the DMatrix.
|
||||
|
||||
Returns
|
||||
-------
|
||||
group
|
||||
"""
|
||||
group_ptr = self.get_uint_info("group_ptr")
|
||||
return np.diff(group_ptr)
|
||||
|
||||
def num_row(self) -> int:
|
||||
"""Get the number of rows in the DMatrix.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user