Support Series and Python primitives in inplace_predict and QDM (#8547)

This commit is contained in:
Jiaming Yuan 2022-12-17 00:15:15 +08:00 committed by GitHub
parent a10e4cba4e
commit f6effa1734
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 84 additions and 46 deletions

View File

@ -2220,17 +2220,15 @@ class Booster:
preds = ctypes.POINTER(ctypes.c_float)()
# once caching is supported, we can pass id(data) as cache id.
args = {
"type": 0,
"training": False,
"iteration_begin": iteration_range[0],
"iteration_end": iteration_range[1],
"missing": missing,
"strict_shape": strict_shape,
"cache_id": 0,
}
if predict_type == "margin":
args["type"] = 1
args = make_jcargs(
type=1 if predict_type == "margin" else 0,
training=False,
iteration_begin=iteration_range[0],
iteration_end=iteration_range[1],
missing=missing,
strict_shape=strict_shape,
cache_id=0,
)
shape = ctypes.POINTER(c_bst_ulong)()
dims = c_bst_ulong()
@ -2243,6 +2241,29 @@ class Booster:
proxy = None
p_handle = ctypes.c_void_p()
assert proxy is None or isinstance(proxy, _ProxyDMatrix)
from .data import (
_array_interface,
_is_cudf_df,
_is_cupy_array,
_is_list,
_is_pandas_df,
_is_pandas_series,
_is_tuple,
_transform_pandas_df,
)
enable_categorical = True
if _is_pandas_series(data):
import pandas as pd
data = pd.DataFrame(data)
if _is_pandas_df(data):
data, fns, _ = _transform_pandas_df(data, enable_categorical)
if validate_features:
self._validate_features(fns)
if _is_list(data) or _is_tuple(data):
data = np.array(data)
if validate_features:
if not hasattr(data, "shape"):
raise TypeError(
@ -2254,20 +2275,6 @@ class Booster:
f"got {data.shape[1]}"
)
from .data import (
_array_interface,
_is_cudf_df,
_is_cupy_array,
_is_pandas_df,
_transform_pandas_df,
)
enable_categorical = True
if _is_pandas_df(data):
data, fns, _ = _transform_pandas_df(data, enable_categorical)
if validate_features:
self._validate_features(fns)
if isinstance(data, np.ndarray):
from .data import _ensure_np_dtype
@ -2276,7 +2283,7 @@ class Booster:
_LIB.XGBoosterPredictFromDense(
self.handle,
_array_interface(data),
from_pystr_to_cstr(json.dumps(args)),
args,
p_handle,
ctypes.byref(shape),
ctypes.byref(dims),
@ -2293,7 +2300,7 @@ class Booster:
_array_interface(csr.indices),
_array_interface(csr.data),
c_bst_ulong(csr.shape[1]),
from_pystr_to_cstr(json.dumps(args)),
args,
p_handle,
ctypes.byref(shape),
ctypes.byref(dims),
@ -2310,7 +2317,7 @@ class Booster:
_LIB.XGBoosterPredictFromCudaArray(
self.handle,
interface_str,
from_pystr_to_cstr(json.dumps(args)),
args,
p_handle,
ctypes.byref(shape),
ctypes.byref(dims),
@ -2331,7 +2338,7 @@ class Booster:
_LIB.XGBoosterPredictFromCudaColumnar(
self.handle,
interfaces_str,
from_pystr_to_cstr(json.dumps(args)),
args,
p_handle,
ctypes.byref(shape),
ctypes.byref(dims),

View File

@ -958,12 +958,12 @@ def dispatch_data_backend(
return _from_list(data, missing, threads, feature_names, feature_types)
if _is_tuple(data):
return _from_tuple(data, missing, threads, feature_names, feature_types)
if _is_pandas_df(data):
return _from_pandas_df(data, enable_categorical, missing, threads,
feature_names, feature_types)
if _is_pandas_series(data):
return _from_pandas_series(
data, missing, threads, enable_categorical, feature_names, feature_types
import pandas as pd
data = pd.DataFrame(data)
if _is_pandas_df(data):
return _from_pandas_df(
data, enable_categorical, missing, threads, feature_names, feature_types
)
if _is_cudf_df(data) or _is_cudf_ser(data):
return _from_cudf_df(
@ -1205,6 +1205,9 @@ def _proxy_transform(
return data, None, feature_names, feature_types
if _is_scipy_csr(data):
return data, None, feature_names, feature_types
if _is_pandas_series(data):
import pandas as pd
data = pd.DataFrame(data)
if _is_pandas_df(data):
arr, feature_names, feature_types = _transform_pandas_df(
data, enable_categorical, feature_names, feature_types

View File

@ -40,6 +40,7 @@ def np_dtypes(
for dtype in dtypes:
X = np.array(orig, dtype=dtype)
yield orig, X
yield orig.tolist(), X.tolist()
for dtype in dtypes:
X = np.array(orig, dtype=dtype)
@ -101,6 +102,11 @@ def pd_dtypes() -> Generator:
{"f0": [1.0, 2.0, Null, 3.0], "f1": [3.0, 2.0, Null, 1.0]}, dtype=dtype
)
yield orig, df
ser_orig = orig["f0"]
ser = df["f0"]
assert isinstance(ser, pd.Series)
assert isinstance(ser_orig, pd.Series)
yield ser_orig, ser
# Categorical
orig = orig.astype("category")

View File

@ -5,7 +5,7 @@ import numpy as np
import pandas as pd
import pytest
from scipy import sparse
from xgboost.testing.data import np_dtypes
from xgboost.testing.data import np_dtypes, pd_dtypes
from xgboost.testing.shared import validate_leaf_output
import xgboost as xgb
@ -231,6 +231,7 @@ class TestInplacePredict:
from_dmatrix = booster.predict(dtrain)
np.testing.assert_allclose(from_dmatrix, from_inplace)
@pytest.mark.skipif(**tm.no_pandas())
def test_dtypes(self) -> None:
for orig, x in np_dtypes(self.rows, self.cols):
predt_orig = self.booster.inplace_predict(orig)
@ -246,3 +247,17 @@ class TestInplacePredict:
X: np.ndarray = np.array(orig, dtype=dtype)
with pytest.raises(ValueError):
self.booster.inplace_predict(X)
@pytest.mark.skipif(**tm.no_pandas())
def test_pd_dtypes(self) -> None:
from pandas.api.types import is_bool_dtype
for orig, x in pd_dtypes():
dtypes = orig.dtypes if isinstance(orig, pd.DataFrame) else [orig.dtypes]
if isinstance(orig, pd.DataFrame) and is_bool_dtype(dtypes[0]):
continue
y = np.arange(x.shape[0])
Xy = xgb.DMatrix(orig, y, enable_categorical=True)
booster = xgb.train({"tree_method": "hist"}, Xy, num_boost_round=1)
predt_orig = booster.inplace_predict(orig)
predt = booster.inplace_predict(x)
np.testing.assert_allclose(predt, predt_orig)

View File

@ -298,22 +298,29 @@ class TestPandas:
assert 'auc' not in cv.columns[0]
assert 'error' in cv.columns[0]
def test_nullable_type(self) -> None:
@pytest.mark.parametrize("DMatrixT", [xgb.DMatrix, xgb.QuantileDMatrix])
def test_nullable_type(self, DMatrixT) -> None:
from pandas.api.types import is_categorical
for DMatrixT in (xgb.DMatrix, xgb.QuantileDMatrix):
for orig, df in pd_dtypes():
for orig, df in pd_dtypes():
if hasattr(df.dtypes, "__iter__"):
enable_categorical = any(is_categorical for dtype in df.dtypes)
else:
# series
enable_categorical = is_categorical(df.dtype)
m_orig = DMatrixT(orig, enable_categorical=enable_categorical)
# extension types
m_etype = DMatrixT(df, enable_categorical=enable_categorical)
# different from pd.BooleanDtype(), None is converted to False with bool
if any(dtype == "bool" for dtype in orig.dtypes):
assert not tm.predictor_equal(m_orig, m_etype)
else:
assert tm.predictor_equal(m_orig, m_etype)
m_orig = DMatrixT(orig, enable_categorical=enable_categorical)
# extension types
m_etype = DMatrixT(df, enable_categorical=enable_categorical)
# different from pd.BooleanDtype(), None is converted to False with bool
if hasattr(orig.dtypes, "__iter__") and any(
dtype == "bool" for dtype in orig.dtypes
):
assert not tm.predictor_equal(m_orig, m_etype)
else:
assert tm.predictor_equal(m_orig, m_etype)
if isinstance(df, pd.DataFrame):
f0 = df["f0"]
with pytest.raises(ValueError, match="Label contains NaN"):
xgb.DMatrix(df, f0, enable_categorical=enable_categorical)