Fix DMatrix slice with feature types. (#6689)

This commit is contained in:
Jiaming Yuan
2021-02-09 08:13:51 +08:00
committed by GitHub
parent 218a5fb6dd
commit 5d48d40d9a
2 changed files with 13 additions and 2 deletions

View File

@@ -18,9 +18,7 @@ rng = np.random.RandomState(1994)
class TestPandas:
def test_pandas(self):
df = pd.DataFrame([[1, 2., True], [2, 3., False]],
columns=['a', 'b', 'c'])
dm = xgb.DMatrix(df, label=pd.Series([1, 2]))
@@ -110,6 +108,18 @@ class TestPandas:
assert dm.num_row() == 2
assert dm.num_col() == 6
def test_slice(self):
rng = np.random.RandomState(1994)
rows = 100
X = rng.randint(3, 7, size=rows)
X = pd.DataFrame({'f0': X})
y = rng.randn(rows)
ridxs = [1, 2, 3, 4, 5, 6]
m = xgb.DMatrix(X, y)
sliced = m.slice(ridxs)
assert m.feature_types == sliced.feature_types
def test_pandas_categorical(self):
rng = np.random.RandomState(1994)
rows = 100