diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 46624fd34..0d14d29dd 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -884,10 +884,20 @@ class DMatrix: # pylint: disable=too-many-instance-attributes Returns ------- - base_margin : float + base_margin """ return self.get_float_info('base_margin') + def get_group(self) -> np.ndarray: + """Get the group of the DMatrix. + + Returns + ------- + group + """ + group_ptr = self.get_uint_info("group_ptr") + return np.diff(group_ptr) + def num_row(self) -> int: """Get the number of rows in the DMatrix. diff --git a/tests/python/test_dmatrix.py b/tests/python/test_dmatrix.py index 6cd026d19..2571d5322 100644 --- a/tests/python/test_dmatrix.py +++ b/tests/python/test_dmatrix.py @@ -285,6 +285,10 @@ class TestDMatrix: dtrain.get_float_info('base_margin') dtrain.get_uint_info('group_ptr') + group_len = np.array([2, 3, 4]) + dtrain.set_group(group_len) + np.testing.assert_equal(group_len, dtrain.get_group()) + def test_qid(self): rows = 100 cols = 10