Support primitive types of pyarrow-backed pandas dataframe. (#8653)
Categorical data (dictionary) is not supported at the moment.
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
from typing import Type
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from test_dmatrix import set_base_margin_info
|
||||
from xgboost.testing.data import pd_dtypes
|
||||
from xgboost.testing.data import pd_arrow_dtypes, pd_dtypes
|
||||
|
||||
import xgboost as xgb
|
||||
from xgboost import testing as tm
|
||||
@@ -305,9 +307,17 @@ class TestPandas:
|
||||
# series
|
||||
enable_categorical = is_categorical(df.dtype)
|
||||
|
||||
m_orig = DMatrixT(orig, enable_categorical=enable_categorical)
|
||||
f0_orig = orig[orig.columns[0]] if isinstance(orig, pd.DataFrame) else orig
|
||||
f0 = df[df.columns[0]] if isinstance(df, pd.DataFrame) else df
|
||||
y_orig = f0_orig.astype(pd.Float32Dtype()).fillna(0)
|
||||
y = f0.astype(pd.Float32Dtype()).fillna(0)
|
||||
|
||||
m_orig = DMatrixT(orig, enable_categorical=enable_categorical, label=y_orig)
|
||||
# extension types
|
||||
m_etype = DMatrixT(df, enable_categorical=enable_categorical)
|
||||
copy = df.copy()
|
||||
m_etype = DMatrixT(df, enable_categorical=enable_categorical, label=y)
|
||||
# no mutation
|
||||
assert df.equals(copy)
|
||||
# different from pd.BooleanDtype(), None is converted to False with bool
|
||||
if hasattr(orig.dtypes, "__iter__") and any(
|
||||
dtype == "bool" for dtype in orig.dtypes
|
||||
@@ -316,7 +326,32 @@ class TestPandas:
|
||||
else:
|
||||
assert tm.predictor_equal(m_orig, m_etype)
|
||||
|
||||
np.testing.assert_allclose(m_orig.get_label(), m_etype.get_label())
|
||||
np.testing.assert_allclose(m_etype.get_label(), y.values.astype(np.float32))
|
||||
|
||||
if isinstance(df, pd.DataFrame):
|
||||
f0 = df["f0"]
|
||||
with pytest.raises(ValueError, match="Label contains NaN"):
|
||||
xgb.DMatrix(df, f0, enable_categorical=enable_categorical)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_arrow())
|
||||
@pytest.mark.parametrize("DMatrixT", [xgb.DMatrix, xgb.QuantileDMatrix])
|
||||
def test_pyarrow_type(self, DMatrixT: Type[xgb.DMatrix]) -> None:
|
||||
for orig, df in pd_arrow_dtypes():
|
||||
f0_orig: pd.Series = orig["f0"]
|
||||
f0 = df["f0"]
|
||||
|
||||
if f0.dtype.name.startswith("bool"):
|
||||
y = None
|
||||
y_orig = None
|
||||
else:
|
||||
y_orig = f0_orig.fillna(0, inplace=False)
|
||||
y = f0.fillna(0, inplace=False)
|
||||
|
||||
m_orig = DMatrixT(orig, enable_categorical=True, label=y_orig)
|
||||
m_etype = DMatrixT(df, enable_categorical=True, label=y)
|
||||
|
||||
assert tm.predictor_equal(m_orig, m_etype)
|
||||
if y is not None:
|
||||
np.testing.assert_allclose(m_orig.get_label(), m_etype.get_label())
|
||||
np.testing.assert_allclose(m_etype.get_label(), y.values)
|
||||
|
||||
Reference in New Issue
Block a user