Fix DMatrix slice with feature types. (#6689)
This commit is contained in:
parent
218a5fb6dd
commit
5d48d40d9a
@ -301,6 +301,7 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
|
|||||||
out.feature_weigths.Copy(this->feature_weigths);
|
out.feature_weigths.Copy(this->feature_weigths);
|
||||||
|
|
||||||
out.feature_names = this->feature_names;
|
out.feature_names = this->feature_names;
|
||||||
|
out.feature_types.Resize(this->feature_types.Size());
|
||||||
out.feature_types.Copy(this->feature_types);
|
out.feature_types.Copy(this->feature_types);
|
||||||
out.feature_type_names = this->feature_type_names;
|
out.feature_type_names = this->feature_type_names;
|
||||||
|
|
||||||
|
|||||||
@ -18,9 +18,7 @@ rng = np.random.RandomState(1994)
|
|||||||
|
|
||||||
|
|
||||||
class TestPandas:
|
class TestPandas:
|
||||||
|
|
||||||
def test_pandas(self):
|
def test_pandas(self):
|
||||||
|
|
||||||
df = pd.DataFrame([[1, 2., True], [2, 3., False]],
|
df = pd.DataFrame([[1, 2., True], [2, 3., False]],
|
||||||
columns=['a', 'b', 'c'])
|
columns=['a', 'b', 'c'])
|
||||||
dm = xgb.DMatrix(df, label=pd.Series([1, 2]))
|
dm = xgb.DMatrix(df, label=pd.Series([1, 2]))
|
||||||
@ -110,6 +108,18 @@ class TestPandas:
|
|||||||
assert dm.num_row() == 2
|
assert dm.num_row() == 2
|
||||||
assert dm.num_col() == 6
|
assert dm.num_col() == 6
|
||||||
|
|
||||||
|
def test_slice(self):
|
||||||
|
rng = np.random.RandomState(1994)
|
||||||
|
rows = 100
|
||||||
|
X = rng.randint(3, 7, size=rows)
|
||||||
|
X = pd.DataFrame({'f0': X})
|
||||||
|
y = rng.randn(rows)
|
||||||
|
ridxs = [1, 2, 3, 4, 5, 6]
|
||||||
|
m = xgb.DMatrix(X, y)
|
||||||
|
sliced = m.slice(ridxs)
|
||||||
|
|
||||||
|
assert m.feature_types == sliced.feature_types
|
||||||
|
|
||||||
def test_pandas_categorical(self):
|
def test_pandas_categorical(self):
|
||||||
rng = np.random.RandomState(1994)
|
rng = np.random.RandomState(1994)
|
||||||
rows = 100
|
rows = 100
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user