Support Series and Python primitives in inplace_predict and QDM (#8547)
This commit is contained in:
parent
a10e4cba4e
commit
f6effa1734
@ -2220,17 +2220,15 @@ class Booster:
|
||||
preds = ctypes.POINTER(ctypes.c_float)()
|
||||
|
||||
# once caching is supported, we can pass id(data) as cache id.
|
||||
args = {
|
||||
"type": 0,
|
||||
"training": False,
|
||||
"iteration_begin": iteration_range[0],
|
||||
"iteration_end": iteration_range[1],
|
||||
"missing": missing,
|
||||
"strict_shape": strict_shape,
|
||||
"cache_id": 0,
|
||||
}
|
||||
if predict_type == "margin":
|
||||
args["type"] = 1
|
||||
args = make_jcargs(
|
||||
type=1 if predict_type == "margin" else 0,
|
||||
training=False,
|
||||
iteration_begin=iteration_range[0],
|
||||
iteration_end=iteration_range[1],
|
||||
missing=missing,
|
||||
strict_shape=strict_shape,
|
||||
cache_id=0,
|
||||
)
|
||||
shape = ctypes.POINTER(c_bst_ulong)()
|
||||
dims = c_bst_ulong()
|
||||
|
||||
@ -2243,6 +2241,29 @@ class Booster:
|
||||
proxy = None
|
||||
p_handle = ctypes.c_void_p()
|
||||
assert proxy is None or isinstance(proxy, _ProxyDMatrix)
|
||||
|
||||
from .data import (
|
||||
_array_interface,
|
||||
_is_cudf_df,
|
||||
_is_cupy_array,
|
||||
_is_list,
|
||||
_is_pandas_df,
|
||||
_is_pandas_series,
|
||||
_is_tuple,
|
||||
_transform_pandas_df,
|
||||
)
|
||||
|
||||
enable_categorical = True
|
||||
if _is_pandas_series(data):
|
||||
import pandas as pd
|
||||
data = pd.DataFrame(data)
|
||||
if _is_pandas_df(data):
|
||||
data, fns, _ = _transform_pandas_df(data, enable_categorical)
|
||||
if validate_features:
|
||||
self._validate_features(fns)
|
||||
if _is_list(data) or _is_tuple(data):
|
||||
data = np.array(data)
|
||||
|
||||
if validate_features:
|
||||
if not hasattr(data, "shape"):
|
||||
raise TypeError(
|
||||
@ -2254,20 +2275,6 @@ class Booster:
|
||||
f"got {data.shape[1]}"
|
||||
)
|
||||
|
||||
from .data import (
|
||||
_array_interface,
|
||||
_is_cudf_df,
|
||||
_is_cupy_array,
|
||||
_is_pandas_df,
|
||||
_transform_pandas_df,
|
||||
)
|
||||
|
||||
enable_categorical = True
|
||||
if _is_pandas_df(data):
|
||||
data, fns, _ = _transform_pandas_df(data, enable_categorical)
|
||||
if validate_features:
|
||||
self._validate_features(fns)
|
||||
|
||||
if isinstance(data, np.ndarray):
|
||||
from .data import _ensure_np_dtype
|
||||
|
||||
@ -2276,7 +2283,7 @@ class Booster:
|
||||
_LIB.XGBoosterPredictFromDense(
|
||||
self.handle,
|
||||
_array_interface(data),
|
||||
from_pystr_to_cstr(json.dumps(args)),
|
||||
args,
|
||||
p_handle,
|
||||
ctypes.byref(shape),
|
||||
ctypes.byref(dims),
|
||||
@ -2293,7 +2300,7 @@ class Booster:
|
||||
_array_interface(csr.indices),
|
||||
_array_interface(csr.data),
|
||||
c_bst_ulong(csr.shape[1]),
|
||||
from_pystr_to_cstr(json.dumps(args)),
|
||||
args,
|
||||
p_handle,
|
||||
ctypes.byref(shape),
|
||||
ctypes.byref(dims),
|
||||
@ -2310,7 +2317,7 @@ class Booster:
|
||||
_LIB.XGBoosterPredictFromCudaArray(
|
||||
self.handle,
|
||||
interface_str,
|
||||
from_pystr_to_cstr(json.dumps(args)),
|
||||
args,
|
||||
p_handle,
|
||||
ctypes.byref(shape),
|
||||
ctypes.byref(dims),
|
||||
@ -2331,7 +2338,7 @@ class Booster:
|
||||
_LIB.XGBoosterPredictFromCudaColumnar(
|
||||
self.handle,
|
||||
interfaces_str,
|
||||
from_pystr_to_cstr(json.dumps(args)),
|
||||
args,
|
||||
p_handle,
|
||||
ctypes.byref(shape),
|
||||
ctypes.byref(dims),
|
||||
|
||||
@ -958,12 +958,12 @@ def dispatch_data_backend(
|
||||
return _from_list(data, missing, threads, feature_names, feature_types)
|
||||
if _is_tuple(data):
|
||||
return _from_tuple(data, missing, threads, feature_names, feature_types)
|
||||
if _is_pandas_df(data):
|
||||
return _from_pandas_df(data, enable_categorical, missing, threads,
|
||||
feature_names, feature_types)
|
||||
if _is_pandas_series(data):
|
||||
return _from_pandas_series(
|
||||
data, missing, threads, enable_categorical, feature_names, feature_types
|
||||
import pandas as pd
|
||||
data = pd.DataFrame(data)
|
||||
if _is_pandas_df(data):
|
||||
return _from_pandas_df(
|
||||
data, enable_categorical, missing, threads, feature_names, feature_types
|
||||
)
|
||||
if _is_cudf_df(data) or _is_cudf_ser(data):
|
||||
return _from_cudf_df(
|
||||
@ -1205,6 +1205,9 @@ def _proxy_transform(
|
||||
return data, None, feature_names, feature_types
|
||||
if _is_scipy_csr(data):
|
||||
return data, None, feature_names, feature_types
|
||||
if _is_pandas_series(data):
|
||||
import pandas as pd
|
||||
data = pd.DataFrame(data)
|
||||
if _is_pandas_df(data):
|
||||
arr, feature_names, feature_types = _transform_pandas_df(
|
||||
data, enable_categorical, feature_names, feature_types
|
||||
|
||||
@ -40,6 +40,7 @@ def np_dtypes(
|
||||
for dtype in dtypes:
|
||||
X = np.array(orig, dtype=dtype)
|
||||
yield orig, X
|
||||
yield orig.tolist(), X.tolist()
|
||||
|
||||
for dtype in dtypes:
|
||||
X = np.array(orig, dtype=dtype)
|
||||
@ -101,6 +102,11 @@ def pd_dtypes() -> Generator:
|
||||
{"f0": [1.0, 2.0, Null, 3.0], "f1": [3.0, 2.0, Null, 1.0]}, dtype=dtype
|
||||
)
|
||||
yield orig, df
|
||||
ser_orig = orig["f0"]
|
||||
ser = df["f0"]
|
||||
assert isinstance(ser, pd.Series)
|
||||
assert isinstance(ser_orig, pd.Series)
|
||||
yield ser_orig, ser
|
||||
|
||||
# Categorical
|
||||
orig = orig.astype("category")
|
||||
|
||||
@ -5,7 +5,7 @@ import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from scipy import sparse
|
||||
from xgboost.testing.data import np_dtypes
|
||||
from xgboost.testing.data import np_dtypes, pd_dtypes
|
||||
from xgboost.testing.shared import validate_leaf_output
|
||||
|
||||
import xgboost as xgb
|
||||
@ -231,6 +231,7 @@ class TestInplacePredict:
|
||||
from_dmatrix = booster.predict(dtrain)
|
||||
np.testing.assert_allclose(from_dmatrix, from_inplace)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_pandas())
|
||||
def test_dtypes(self) -> None:
|
||||
for orig, x in np_dtypes(self.rows, self.cols):
|
||||
predt_orig = self.booster.inplace_predict(orig)
|
||||
@ -246,3 +247,17 @@ class TestInplacePredict:
|
||||
X: np.ndarray = np.array(orig, dtype=dtype)
|
||||
with pytest.raises(ValueError):
|
||||
self.booster.inplace_predict(X)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_pandas())
|
||||
def test_pd_dtypes(self) -> None:
|
||||
from pandas.api.types import is_bool_dtype
|
||||
for orig, x in pd_dtypes():
|
||||
dtypes = orig.dtypes if isinstance(orig, pd.DataFrame) else [orig.dtypes]
|
||||
if isinstance(orig, pd.DataFrame) and is_bool_dtype(dtypes[0]):
|
||||
continue
|
||||
y = np.arange(x.shape[0])
|
||||
Xy = xgb.DMatrix(orig, y, enable_categorical=True)
|
||||
booster = xgb.train({"tree_method": "hist"}, Xy, num_boost_round=1)
|
||||
predt_orig = booster.inplace_predict(orig)
|
||||
predt = booster.inplace_predict(x)
|
||||
np.testing.assert_allclose(predt, predt_orig)
|
||||
|
||||
@ -298,22 +298,29 @@ class TestPandas:
|
||||
assert 'auc' not in cv.columns[0]
|
||||
assert 'error' in cv.columns[0]
|
||||
|
||||
def test_nullable_type(self) -> None:
|
||||
@pytest.mark.parametrize("DMatrixT", [xgb.DMatrix, xgb.QuantileDMatrix])
|
||||
def test_nullable_type(self, DMatrixT) -> None:
|
||||
from pandas.api.types import is_categorical
|
||||
|
||||
for DMatrixT in (xgb.DMatrix, xgb.QuantileDMatrix):
|
||||
for orig, df in pd_dtypes():
|
||||
if hasattr(df.dtypes, "__iter__"):
|
||||
enable_categorical = any(is_categorical for dtype in df.dtypes)
|
||||
else:
|
||||
# series
|
||||
enable_categorical = is_categorical(df.dtype)
|
||||
|
||||
m_orig = DMatrixT(orig, enable_categorical=enable_categorical)
|
||||
# extension types
|
||||
m_etype = DMatrixT(df, enable_categorical=enable_categorical)
|
||||
# different from pd.BooleanDtype(), None is converted to False with bool
|
||||
if any(dtype == "bool" for dtype in orig.dtypes):
|
||||
if hasattr(orig.dtypes, "__iter__") and any(
|
||||
dtype == "bool" for dtype in orig.dtypes
|
||||
):
|
||||
assert not tm.predictor_equal(m_orig, m_etype)
|
||||
else:
|
||||
assert tm.predictor_equal(m_orig, m_etype)
|
||||
|
||||
if isinstance(df, pd.DataFrame):
|
||||
f0 = df["f0"]
|
||||
with pytest.raises(ValueError, match="Label contains NaN"):
|
||||
xgb.DMatrix(df, f0, enable_categorical=enable_categorical)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user