Accept numpy array for DMatrix slice index. (#6368)

This commit is contained in:
Jiaming Yuan
2020-12-16 14:42:52 +08:00
committed by GitHub
parent ef4a0e0aac
commit 347f593169
3 changed files with 28 additions and 14 deletions

View File

@@ -145,6 +145,10 @@ class TestDMatrix:
num_boost_round=2, evals=[(d2, 'd2'), (sliced, 'sliced')], evals_result=eval_res)
np.testing.assert_equal(eval_res['d2']['mlogloss'], eval_res['sliced']['mlogloss'])
ridxs_arr = np.array(ridxs)[1:] # handles numpy slice correctly
sliced = d.slice(ridxs_arr)
np.testing.assert_equal(sliced.get_label(), y[2:7])
def test_feature_names_slice(self):
data = np.random.randn(5, 5)