Support arrow through pandas ext types. (#9612)
- Use pandas extension types for pyarrow support.
- Add support for QuantileDMatrix (QDM).
- Add support for inplace_predict.
This commit is contained in:
@@ -22,7 +22,7 @@ pytestmark = pytest.mark.skipif(
|
||||
dpath = "demo/data/"
|
||||
|
||||
|
||||
class TestArrowTable(unittest.TestCase):
|
||||
class TestArrowTable:
|
||||
def test_arrow_table(self):
|
||||
df = pd.DataFrame(
|
||||
[[0, 1, 2.0, 3.0], [1, 2, 3.0, 4.0]], columns=["a", "b", "c", "d"]
|
||||
@@ -52,7 +52,8 @@ class TestArrowTable(unittest.TestCase):
|
||||
assert dm.num_row() == 4
|
||||
assert dm.num_col() == 3
|
||||
|
||||
def test_arrow_train(self):
|
||||
@pytest.mark.parametrize("DMatrixT", [xgb.DMatrix, xgb.QuantileDMatrix])
|
||||
def test_arrow_train(self, DMatrixT):
|
||||
import pandas as pd
|
||||
|
||||
rows = 100
|
||||
@@ -64,16 +65,24 @@ class TestArrowTable(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
y = pd.Series(np.random.randn(rows))
|
||||
|
||||
table = pa.Table.from_pandas(X)
|
||||
dtrain1 = xgb.DMatrix(table)
|
||||
dtrain1.set_label(y)
|
||||
dtrain1 = DMatrixT(table)
|
||||
dtrain1.set_label(pa.Table.from_pandas(pd.DataFrame(y)))
|
||||
bst1 = xgb.train({}, dtrain1, num_boost_round=10)
|
||||
preds1 = bst1.predict(xgb.DMatrix(X))
|
||||
dtrain2 = xgb.DMatrix(X, y)
|
||||
preds1 = bst1.predict(DMatrixT(X))
|
||||
|
||||
dtrain2 = DMatrixT(X, y)
|
||||
bst2 = xgb.train({}, dtrain2, num_boost_round=10)
|
||||
preds2 = bst2.predict(xgb.DMatrix(X))
|
||||
preds2 = bst2.predict(DMatrixT(X))
|
||||
|
||||
np.testing.assert_allclose(preds1, preds2)
|
||||
|
||||
preds3 = bst2.inplace_predict(table)
|
||||
np.testing.assert_allclose(preds1, preds3)
|
||||
assert bst2.feature_names == ["A", "B", "C"]
|
||||
assert bst2.feature_types == ["int", "float", "int"]
|
||||
|
||||
def test_arrow_survival(self):
|
||||
data = os.path.join(tm.data_dir(__file__), "veterans_lung_cancer.csv")
|
||||
table = pc.read_csv(data)
|
||||
|
||||
Reference in New Issue
Block a user