Keep feature types when slicing MetaInfo.

MetaInfo::Slice now resizes out.feature_types to the size of the source
vector before copying into it. A new Python test, TestPandas.test_slice,
asserts that a sliced DMatrix reports the same feature_types as the
original. Two blank lines are also removed from test_pandas.

NOTE(review): presumably Copy() expects an equally sized destination;
confirm against the feature_types container's implementation.
NOTE(review): line breaks, blank context lines, wrap points and indentation
below were reconstructed from the hunk line counts and the projects' usual
style; run `git apply --check` before relying on this patch.

diff --git a/src/data/data.cc b/src/data/data.cc
index 60afc0511..0e90584e0 100644
--- a/src/data/data.cc
+++ b/src/data/data.cc
@@ -301,6 +301,7 @@ MetaInfo MetaInfo::Slice(common::Span ridxs) const {
   out.feature_weigths.Copy(this->feature_weigths);
 
   out.feature_names = this->feature_names;
+  out.feature_types.Resize(this->feature_types.Size());
   out.feature_types.Copy(this->feature_types);
   out.feature_type_names = this->feature_type_names;
 
diff --git a/tests/python/test_with_pandas.py b/tests/python/test_with_pandas.py
index df79b09fc..c9400bfa6 100644
--- a/tests/python/test_with_pandas.py
+++ b/tests/python/test_with_pandas.py
@@ -18,9 +18,7 @@ rng = np.random.RandomState(1994)
 
 
 class TestPandas:
-
     def test_pandas(self):
-
         df = pd.DataFrame([[1, 2., True], [2, 3., False]],
                           columns=['a', 'b', 'c'])
         dm = xgb.DMatrix(df, label=pd.Series([1, 2]))
@@ -110,6 +108,18 @@ class TestPandas:
         assert dm.num_row() == 2
         assert dm.num_col() == 6
 
+    def test_slice(self):
+        rng = np.random.RandomState(1994)
+        rows = 100
+        X = rng.randint(3, 7, size=rows)
+        X = pd.DataFrame({'f0': X})
+        y = rng.randn(rows)
+        ridxs = [1, 2, 3, 4, 5, 6]
+        m = xgb.DMatrix(X, y)
+        sliced = m.slice(ridxs)
+
+        assert m.feature_types == sliced.feature_types
+
     def test_pandas_categorical(self):
         rng = np.random.RandomState(1994)
         rows = 100