Fix slice and get info. (#5552)
This commit is contained in:
@@ -71,7 +71,34 @@ class TestDMatrix(unittest.TestCase):
|
||||
assert (from_view.shape == from_array.shape)
|
||||
assert (from_view == from_array).all()
|
||||
|
||||
def test_feature_names(self):
|
||||
def test_slice(self):
|
||||
X = rng.randn(100, 100)
|
||||
y = rng.randint(low=0, high=3, size=100)
|
||||
d = xgb.DMatrix(X, y)
|
||||
eval_res_0 = {}
|
||||
booster = xgb.train(
|
||||
{'num_class': 3, 'objective': 'multi:softprob'}, d,
|
||||
num_boost_round=2, evals=[(d, 'd')], evals_result=eval_res_0)
|
||||
|
||||
predt = booster.predict(d)
|
||||
predt = predt.reshape(100 * 3, 1)
|
||||
d.set_base_margin(predt)
|
||||
|
||||
ridxs = [1, 2, 3, 4, 5, 6]
|
||||
d = d.slice(ridxs)
|
||||
sliced_margin = d.get_float_info('base_margin')
|
||||
assert sliced_margin.shape[0] == len(ridxs) * 3
|
||||
|
||||
eval_res_1 = {}
|
||||
xgb.train({'num_class': 3, 'objective': 'multi:softprob'}, d,
|
||||
num_boost_round=2, evals=[(d, 'd')], evals_result=eval_res_1)
|
||||
|
||||
eval_res_0 = eval_res_0['d']['merror']
|
||||
eval_res_1 = eval_res_1['d']['merror']
|
||||
for i in range(len(eval_res_0)):
|
||||
assert abs(eval_res_0[i] - eval_res_1[i]) < 0.02
|
||||
|
||||
def test_feature_names_slice(self):
|
||||
data = np.random.randn(5, 5)
|
||||
|
||||
# different length
|
||||
|
||||
Reference in New Issue
Block a user