[Breaking] Remove num roots. (#5059)
This commit is contained in:
@@ -63,11 +63,11 @@ class TestDMatrix(unittest.TestCase):
|
||||
# Sliced UInt array
|
||||
z = np.array([12, 34, 56], np.uint32)[::2]
|
||||
dmat = xgb.DMatrix(np.array([[]]))
|
||||
dmat.set_uint_info('root_index', z)
|
||||
from_view = dmat.get_uint_info('root_index')
|
||||
dmat.set_uint_info('group', z)
|
||||
from_view = dmat.get_uint_info('group_ptr')
|
||||
dmat = xgb.DMatrix(np.array([[]]))
|
||||
dmat.set_uint_info('root_index', z + 0)
|
||||
from_array = dmat.get_uint_info('root_index')
|
||||
dmat.set_uint_info('group', z + 0)
|
||||
from_array = dmat.get_uint_info('group_ptr')
|
||||
assert (from_view.shape == from_array.shape)
|
||||
assert (from_view == from_array).all()
|
||||
|
||||
@@ -142,7 +142,7 @@ class TestDMatrix(unittest.TestCase):
|
||||
dtrain.get_float_info('label')
|
||||
dtrain.get_float_info('weight')
|
||||
dtrain.get_float_info('base_margin')
|
||||
dtrain.get_uint_info('root_index')
|
||||
dtrain.get_uint_info('group_ptr')
|
||||
|
||||
def test_sparse_dmatrix_csr(self):
|
||||
nrow = 100
|
||||
|
||||
Reference in New Issue
Block a user