Support pandas SparseArray. (#5431)
This commit is contained in:
@@ -109,6 +109,22 @@ class TestPandas(unittest.TestCase):
|
||||
assert dm.num_row() == 2
|
||||
assert dm.num_col() == 6
|
||||
|
||||
def test_pandas_sparse(self):
|
||||
import pandas as pd
|
||||
rows = 100
|
||||
X = pd.DataFrame(
|
||||
{"A": pd.SparseArray(np.random.randint(0, 10, size=rows)),
|
||||
"B": pd.SparseArray(np.random.randn(rows)),
|
||||
"C": pd.SparseArray(np.random.permutation(
|
||||
[True, False] * (rows // 2)))}
|
||||
)
|
||||
y = pd.Series(pd.SparseArray(np.random.randn(rows)))
|
||||
dtrain = xgb.DMatrix(X, y)
|
||||
booster = xgb.train({}, dtrain, num_boost_round=4)
|
||||
predt_sparse = booster.predict(xgb.DMatrix(X))
|
||||
predt_dense = booster.predict(xgb.DMatrix(X.sparse.to_dense()))
|
||||
np.testing.assert_allclose(predt_sparse, predt_dense)
|
||||
|
||||
def test_pandas_label(self):
|
||||
# label must be a single column
|
||||
df = pd.DataFrame({'A': ['X', 'Y', 'Z'], 'B': [1, 2, 3]})
|
||||
|
||||
Reference in New Issue
Block a user