Support pandas SparseArray. (#5431)

This commit is contained in:
Jiaming Yuan
2020-03-20 21:40:22 +08:00
committed by GitHub
parent 3cf665d3ec
commit abca9908ba
2 changed files with 25 additions and 5 deletions

View File

@@ -109,6 +109,22 @@ class TestPandas(unittest.TestCase):
assert dm.num_row() == 2
assert dm.num_col() == 6
def test_pandas_sparse(self):
import pandas as pd
rows = 100
X = pd.DataFrame(
{"A": pd.SparseArray(np.random.randint(0, 10, size=rows)),
"B": pd.SparseArray(np.random.randn(rows)),
"C": pd.SparseArray(np.random.permutation(
[True, False] * (rows // 2)))}
)
y = pd.Series(pd.SparseArray(np.random.randn(rows)))
dtrain = xgb.DMatrix(X, y)
booster = xgb.train({}, dtrain, num_boost_round=4)
predt_sparse = booster.predict(xgb.DMatrix(X))
predt_dense = booster.predict(xgb.DMatrix(X.sparse.to_dense()))
np.testing.assert_allclose(predt_sparse, predt_dense)
def test_pandas_label(self):
# label must be a single column
df = pd.DataFrame({'A': ['X', 'Y', 'Z'], 'B': [1, 2, 3]})