Implement get_group. (#7564)
This commit is contained in:
@@ -285,6 +285,10 @@ class TestDMatrix:
|
||||
dtrain.get_float_info('base_margin')
|
||||
dtrain.get_uint_info('group_ptr')
|
||||
|
||||
group_len = np.array([2, 3, 4])
|
||||
dtrain.set_group(group_len)
|
||||
np.testing.assert_equal(group_len, dtrain.get_group())
|
||||
|
||||
def test_qid(self):
|
||||
rows = 100
|
||||
cols = 10
|
||||
|
||||
Reference in New Issue
Block a user