Fix DMatrix slice with feature types. (#6689)

This commit is contained in:
Jiaming Yuan 2021-02-09 08:13:51 +08:00 committed by GitHub
parent 218a5fb6dd
commit 5d48d40d9a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 2 deletions

View File

@@ -301,6 +301,7 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
out.feature_weigths.Copy(this->feature_weigths);
out.feature_names = this->feature_names;
out.feature_types.Resize(this->feature_types.Size());
out.feature_types.Copy(this->feature_types);
out.feature_type_names = this->feature_type_names;

View File

@@ -18,9 +18,7 @@ rng = np.random.RandomState(1994)
class TestPandas:
def test_pandas(self):
df = pd.DataFrame([[1, 2., True], [2, 3., False]],
columns=['a', 'b', 'c'])
dm = xgb.DMatrix(df, label=pd.Series([1, 2]))
@@ -110,6 +108,18 @@ class TestPandas:
assert dm.num_row() == 2
assert dm.num_col() == 6
def test_slice(self):
    """Check that a sliced DMatrix reports the same feature types as its parent."""
    n_samples = 100
    random_state = np.random.RandomState(1994)
    # The feature column is drawn before the labels so the seeded random
    # stream is consumed in a fixed order.
    frame = pd.DataFrame({'f0': random_state.randint(3, 7, size=n_samples)})
    labels = random_state.randn(n_samples)

    full = xgb.DMatrix(frame, labels)
    subset = full.slice(list(range(1, 7)))

    assert full.feature_types == subset.feature_types
def test_pandas_categorical(self):
rng = np.random.RandomState(1994)
rows = 100