Support arrow through pandas ext types. (#9612)

- Use pandas extension type for pyarrow support.
- Additional support for QDM.
- Additional support for inplace_predict.
This commit is contained in:
Jiaming Yuan
2023-09-28 17:00:16 +08:00
committed by GitHub
parent 3f2093fb81
commit 60526100e3
11 changed files with 74 additions and 584 deletions

View File

@@ -22,7 +22,7 @@ pytestmark = pytest.mark.skipif(
dpath = "demo/data/"
class TestArrowTable(unittest.TestCase):
class TestArrowTable:
def test_arrow_table(self):
df = pd.DataFrame(
[[0, 1, 2.0, 3.0], [1, 2, 3.0, 4.0]], columns=["a", "b", "c", "d"]
@@ -52,7 +52,8 @@ class TestArrowTable(unittest.TestCase):
assert dm.num_row() == 4
assert dm.num_col() == 3
def test_arrow_train(self):
@pytest.mark.parametrize("DMatrixT", [xgb.DMatrix, xgb.QuantileDMatrix])
def test_arrow_train(self, DMatrixT):
import pandas as pd
rows = 100
@@ -64,16 +65,24 @@ class TestArrowTable(unittest.TestCase):
}
)
y = pd.Series(np.random.randn(rows))
table = pa.Table.from_pandas(X)
dtrain1 = xgb.DMatrix(table)
dtrain1.set_label(y)
dtrain1 = DMatrixT(table)
dtrain1.set_label(pa.Table.from_pandas(pd.DataFrame(y)))
bst1 = xgb.train({}, dtrain1, num_boost_round=10)
preds1 = bst1.predict(xgb.DMatrix(X))
dtrain2 = xgb.DMatrix(X, y)
preds1 = bst1.predict(DMatrixT(X))
dtrain2 = DMatrixT(X, y)
bst2 = xgb.train({}, dtrain2, num_boost_round=10)
preds2 = bst2.predict(xgb.DMatrix(X))
preds2 = bst2.predict(DMatrixT(X))
np.testing.assert_allclose(preds1, preds2)
preds3 = bst2.inplace_predict(table)
np.testing.assert_allclose(preds1, preds3)
assert bst2.feature_names == ["A", "B", "C"]
assert bst2.feature_types == ["int", "float", "int"]
def test_arrow_survival(self):
data = os.path.join(tm.data_dir(__file__), "veterans_lung_cancer.csv")
table = pc.read_csv(data)