From 13b0fa4b97977660c5ada88e898ae8edd61640ec Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sun, 16 Jan 2022 02:07:42 +0800 Subject: [PATCH] Implement `get_group`. (#7564) --- python-package/xgboost/core.py | 12 +++++++++++- tests/python/test_dmatrix.py | 4 ++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 46624fd34..0d14d29dd 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -884,10 +884,20 @@ class DMatrix: # pylint: disable=too-many-instance-attributes Returns ------- - base_margin : float + base_margin """ return self.get_float_info('base_margin') + def get_group(self) -> np.ndarray: + """Get the group of the DMatrix. + + Returns + ------- + group + """ + group_ptr = self.get_uint_info("group_ptr") + return np.diff(group_ptr) + def num_row(self) -> int: """Get the number of rows in the DMatrix. diff --git a/tests/python/test_dmatrix.py b/tests/python/test_dmatrix.py index 6cd026d19..2571d5322 100644 --- a/tests/python/test_dmatrix.py +++ b/tests/python/test_dmatrix.py @@ -285,6 +285,10 @@ class TestDMatrix: dtrain.get_float_info('base_margin') dtrain.get_uint_info('group_ptr') + group_len = np.array([2, 3, 4]) + dtrain.set_group(group_len) + np.testing.assert_equal(group_len, dtrain.get_group()) + def test_qid(self): rows = 100 cols = 10