Implement get_group. (#7564)

This commit is contained in:
Jiaming Yuan
2022-01-16 02:07:42 +08:00
committed by GitHub
parent 52277cc3da
commit 13b0fa4b97
2 changed files with 15 additions and 1 deletions

View File

@@ -285,6 +285,10 @@ class TestDMatrix:
dtrain.get_float_info('base_margin')
dtrain.get_uint_info('group_ptr')
group_len = np.array([2, 3, 4])
dtrain.set_group(group_len)
np.testing.assert_equal(group_len, dtrain.get_group())
def test_qid(self):
rows = 100
cols = 10