Support Series and Python primitives in inplace_predict and QDM (#8547)
This commit is contained in:
parent
a10e4cba4e
commit
f6effa1734
@ -2220,17 +2220,15 @@ class Booster:
|
|||||||
preds = ctypes.POINTER(ctypes.c_float)()
|
preds = ctypes.POINTER(ctypes.c_float)()
|
||||||
|
|
||||||
# once caching is supported, we can pass id(data) as cache id.
|
# once caching is supported, we can pass id(data) as cache id.
|
||||||
args = {
|
args = make_jcargs(
|
||||||
"type": 0,
|
type=1 if predict_type == "margin" else 0,
|
||||||
"training": False,
|
training=False,
|
||||||
"iteration_begin": iteration_range[0],
|
iteration_begin=iteration_range[0],
|
||||||
"iteration_end": iteration_range[1],
|
iteration_end=iteration_range[1],
|
||||||
"missing": missing,
|
missing=missing,
|
||||||
"strict_shape": strict_shape,
|
strict_shape=strict_shape,
|
||||||
"cache_id": 0,
|
cache_id=0,
|
||||||
}
|
)
|
||||||
if predict_type == "margin":
|
|
||||||
args["type"] = 1
|
|
||||||
shape = ctypes.POINTER(c_bst_ulong)()
|
shape = ctypes.POINTER(c_bst_ulong)()
|
||||||
dims = c_bst_ulong()
|
dims = c_bst_ulong()
|
||||||
|
|
||||||
@ -2243,6 +2241,29 @@ class Booster:
|
|||||||
proxy = None
|
proxy = None
|
||||||
p_handle = ctypes.c_void_p()
|
p_handle = ctypes.c_void_p()
|
||||||
assert proxy is None or isinstance(proxy, _ProxyDMatrix)
|
assert proxy is None or isinstance(proxy, _ProxyDMatrix)
|
||||||
|
|
||||||
|
from .data import (
|
||||||
|
_array_interface,
|
||||||
|
_is_cudf_df,
|
||||||
|
_is_cupy_array,
|
||||||
|
_is_list,
|
||||||
|
_is_pandas_df,
|
||||||
|
_is_pandas_series,
|
||||||
|
_is_tuple,
|
||||||
|
_transform_pandas_df,
|
||||||
|
)
|
||||||
|
|
||||||
|
enable_categorical = True
|
||||||
|
if _is_pandas_series(data):
|
||||||
|
import pandas as pd
|
||||||
|
data = pd.DataFrame(data)
|
||||||
|
if _is_pandas_df(data):
|
||||||
|
data, fns, _ = _transform_pandas_df(data, enable_categorical)
|
||||||
|
if validate_features:
|
||||||
|
self._validate_features(fns)
|
||||||
|
if _is_list(data) or _is_tuple(data):
|
||||||
|
data = np.array(data)
|
||||||
|
|
||||||
if validate_features:
|
if validate_features:
|
||||||
if not hasattr(data, "shape"):
|
if not hasattr(data, "shape"):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
@ -2254,20 +2275,6 @@ class Booster:
|
|||||||
f"got {data.shape[1]}"
|
f"got {data.shape[1]}"
|
||||||
)
|
)
|
||||||
|
|
||||||
from .data import (
|
|
||||||
_array_interface,
|
|
||||||
_is_cudf_df,
|
|
||||||
_is_cupy_array,
|
|
||||||
_is_pandas_df,
|
|
||||||
_transform_pandas_df,
|
|
||||||
)
|
|
||||||
|
|
||||||
enable_categorical = True
|
|
||||||
if _is_pandas_df(data):
|
|
||||||
data, fns, _ = _transform_pandas_df(data, enable_categorical)
|
|
||||||
if validate_features:
|
|
||||||
self._validate_features(fns)
|
|
||||||
|
|
||||||
if isinstance(data, np.ndarray):
|
if isinstance(data, np.ndarray):
|
||||||
from .data import _ensure_np_dtype
|
from .data import _ensure_np_dtype
|
||||||
|
|
||||||
@ -2276,7 +2283,7 @@ class Booster:
|
|||||||
_LIB.XGBoosterPredictFromDense(
|
_LIB.XGBoosterPredictFromDense(
|
||||||
self.handle,
|
self.handle,
|
||||||
_array_interface(data),
|
_array_interface(data),
|
||||||
from_pystr_to_cstr(json.dumps(args)),
|
args,
|
||||||
p_handle,
|
p_handle,
|
||||||
ctypes.byref(shape),
|
ctypes.byref(shape),
|
||||||
ctypes.byref(dims),
|
ctypes.byref(dims),
|
||||||
@ -2293,7 +2300,7 @@ class Booster:
|
|||||||
_array_interface(csr.indices),
|
_array_interface(csr.indices),
|
||||||
_array_interface(csr.data),
|
_array_interface(csr.data),
|
||||||
c_bst_ulong(csr.shape[1]),
|
c_bst_ulong(csr.shape[1]),
|
||||||
from_pystr_to_cstr(json.dumps(args)),
|
args,
|
||||||
p_handle,
|
p_handle,
|
||||||
ctypes.byref(shape),
|
ctypes.byref(shape),
|
||||||
ctypes.byref(dims),
|
ctypes.byref(dims),
|
||||||
@ -2310,7 +2317,7 @@ class Booster:
|
|||||||
_LIB.XGBoosterPredictFromCudaArray(
|
_LIB.XGBoosterPredictFromCudaArray(
|
||||||
self.handle,
|
self.handle,
|
||||||
interface_str,
|
interface_str,
|
||||||
from_pystr_to_cstr(json.dumps(args)),
|
args,
|
||||||
p_handle,
|
p_handle,
|
||||||
ctypes.byref(shape),
|
ctypes.byref(shape),
|
||||||
ctypes.byref(dims),
|
ctypes.byref(dims),
|
||||||
@ -2331,7 +2338,7 @@ class Booster:
|
|||||||
_LIB.XGBoosterPredictFromCudaColumnar(
|
_LIB.XGBoosterPredictFromCudaColumnar(
|
||||||
self.handle,
|
self.handle,
|
||||||
interfaces_str,
|
interfaces_str,
|
||||||
from_pystr_to_cstr(json.dumps(args)),
|
args,
|
||||||
p_handle,
|
p_handle,
|
||||||
ctypes.byref(shape),
|
ctypes.byref(shape),
|
||||||
ctypes.byref(dims),
|
ctypes.byref(dims),
|
||||||
|
|||||||
@ -958,12 +958,12 @@ def dispatch_data_backend(
|
|||||||
return _from_list(data, missing, threads, feature_names, feature_types)
|
return _from_list(data, missing, threads, feature_names, feature_types)
|
||||||
if _is_tuple(data):
|
if _is_tuple(data):
|
||||||
return _from_tuple(data, missing, threads, feature_names, feature_types)
|
return _from_tuple(data, missing, threads, feature_names, feature_types)
|
||||||
if _is_pandas_df(data):
|
|
||||||
return _from_pandas_df(data, enable_categorical, missing, threads,
|
|
||||||
feature_names, feature_types)
|
|
||||||
if _is_pandas_series(data):
|
if _is_pandas_series(data):
|
||||||
return _from_pandas_series(
|
import pandas as pd
|
||||||
data, missing, threads, enable_categorical, feature_names, feature_types
|
data = pd.DataFrame(data)
|
||||||
|
if _is_pandas_df(data):
|
||||||
|
return _from_pandas_df(
|
||||||
|
data, enable_categorical, missing, threads, feature_names, feature_types
|
||||||
)
|
)
|
||||||
if _is_cudf_df(data) or _is_cudf_ser(data):
|
if _is_cudf_df(data) or _is_cudf_ser(data):
|
||||||
return _from_cudf_df(
|
return _from_cudf_df(
|
||||||
@ -1205,6 +1205,9 @@ def _proxy_transform(
|
|||||||
return data, None, feature_names, feature_types
|
return data, None, feature_names, feature_types
|
||||||
if _is_scipy_csr(data):
|
if _is_scipy_csr(data):
|
||||||
return data, None, feature_names, feature_types
|
return data, None, feature_names, feature_types
|
||||||
|
if _is_pandas_series(data):
|
||||||
|
import pandas as pd
|
||||||
|
data = pd.DataFrame(data)
|
||||||
if _is_pandas_df(data):
|
if _is_pandas_df(data):
|
||||||
arr, feature_names, feature_types = _transform_pandas_df(
|
arr, feature_names, feature_types = _transform_pandas_df(
|
||||||
data, enable_categorical, feature_names, feature_types
|
data, enable_categorical, feature_names, feature_types
|
||||||
|
|||||||
@ -40,6 +40,7 @@ def np_dtypes(
|
|||||||
for dtype in dtypes:
|
for dtype in dtypes:
|
||||||
X = np.array(orig, dtype=dtype)
|
X = np.array(orig, dtype=dtype)
|
||||||
yield orig, X
|
yield orig, X
|
||||||
|
yield orig.tolist(), X.tolist()
|
||||||
|
|
||||||
for dtype in dtypes:
|
for dtype in dtypes:
|
||||||
X = np.array(orig, dtype=dtype)
|
X = np.array(orig, dtype=dtype)
|
||||||
@ -101,6 +102,11 @@ def pd_dtypes() -> Generator:
|
|||||||
{"f0": [1.0, 2.0, Null, 3.0], "f1": [3.0, 2.0, Null, 1.0]}, dtype=dtype
|
{"f0": [1.0, 2.0, Null, 3.0], "f1": [3.0, 2.0, Null, 1.0]}, dtype=dtype
|
||||||
)
|
)
|
||||||
yield orig, df
|
yield orig, df
|
||||||
|
ser_orig = orig["f0"]
|
||||||
|
ser = df["f0"]
|
||||||
|
assert isinstance(ser, pd.Series)
|
||||||
|
assert isinstance(ser_orig, pd.Series)
|
||||||
|
yield ser_orig, ser
|
||||||
|
|
||||||
# Categorical
|
# Categorical
|
||||||
orig = orig.astype("category")
|
orig = orig.astype("category")
|
||||||
|
|||||||
@ -5,7 +5,7 @@ import numpy as np
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
import pytest
|
import pytest
|
||||||
from scipy import sparse
|
from scipy import sparse
|
||||||
from xgboost.testing.data import np_dtypes
|
from xgboost.testing.data import np_dtypes, pd_dtypes
|
||||||
from xgboost.testing.shared import validate_leaf_output
|
from xgboost.testing.shared import validate_leaf_output
|
||||||
|
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
@ -231,6 +231,7 @@ class TestInplacePredict:
|
|||||||
from_dmatrix = booster.predict(dtrain)
|
from_dmatrix = booster.predict(dtrain)
|
||||||
np.testing.assert_allclose(from_dmatrix, from_inplace)
|
np.testing.assert_allclose(from_dmatrix, from_inplace)
|
||||||
|
|
||||||
|
@pytest.mark.skipif(**tm.no_pandas())
|
||||||
def test_dtypes(self) -> None:
|
def test_dtypes(self) -> None:
|
||||||
for orig, x in np_dtypes(self.rows, self.cols):
|
for orig, x in np_dtypes(self.rows, self.cols):
|
||||||
predt_orig = self.booster.inplace_predict(orig)
|
predt_orig = self.booster.inplace_predict(orig)
|
||||||
@ -246,3 +247,17 @@ class TestInplacePredict:
|
|||||||
X: np.ndarray = np.array(orig, dtype=dtype)
|
X: np.ndarray = np.array(orig, dtype=dtype)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
self.booster.inplace_predict(X)
|
self.booster.inplace_predict(X)
|
||||||
|
|
||||||
|
@pytest.mark.skipif(**tm.no_pandas())
|
||||||
|
def test_pd_dtypes(self) -> None:
|
||||||
|
from pandas.api.types import is_bool_dtype
|
||||||
|
for orig, x in pd_dtypes():
|
||||||
|
dtypes = orig.dtypes if isinstance(orig, pd.DataFrame) else [orig.dtypes]
|
||||||
|
if isinstance(orig, pd.DataFrame) and is_bool_dtype(dtypes[0]):
|
||||||
|
continue
|
||||||
|
y = np.arange(x.shape[0])
|
||||||
|
Xy = xgb.DMatrix(orig, y, enable_categorical=True)
|
||||||
|
booster = xgb.train({"tree_method": "hist"}, Xy, num_boost_round=1)
|
||||||
|
predt_orig = booster.inplace_predict(orig)
|
||||||
|
predt = booster.inplace_predict(x)
|
||||||
|
np.testing.assert_allclose(predt, predt_orig)
|
||||||
|
|||||||
@ -298,22 +298,29 @@ class TestPandas:
|
|||||||
assert 'auc' not in cv.columns[0]
|
assert 'auc' not in cv.columns[0]
|
||||||
assert 'error' in cv.columns[0]
|
assert 'error' in cv.columns[0]
|
||||||
|
|
||||||
def test_nullable_type(self) -> None:
|
@pytest.mark.parametrize("DMatrixT", [xgb.DMatrix, xgb.QuantileDMatrix])
|
||||||
|
def test_nullable_type(self, DMatrixT) -> None:
|
||||||
from pandas.api.types import is_categorical
|
from pandas.api.types import is_categorical
|
||||||
|
|
||||||
for DMatrixT in (xgb.DMatrix, xgb.QuantileDMatrix):
|
|
||||||
for orig, df in pd_dtypes():
|
for orig, df in pd_dtypes():
|
||||||
|
if hasattr(df.dtypes, "__iter__"):
|
||||||
enable_categorical = any(is_categorical for dtype in df.dtypes)
|
enable_categorical = any(is_categorical for dtype in df.dtypes)
|
||||||
|
else:
|
||||||
|
# series
|
||||||
|
enable_categorical = is_categorical(df.dtype)
|
||||||
|
|
||||||
m_orig = DMatrixT(orig, enable_categorical=enable_categorical)
|
m_orig = DMatrixT(orig, enable_categorical=enable_categorical)
|
||||||
# extension types
|
# extension types
|
||||||
m_etype = DMatrixT(df, enable_categorical=enable_categorical)
|
m_etype = DMatrixT(df, enable_categorical=enable_categorical)
|
||||||
# different from pd.BooleanDtype(), None is converted to False with bool
|
# different from pd.BooleanDtype(), None is converted to False with bool
|
||||||
if any(dtype == "bool" for dtype in orig.dtypes):
|
if hasattr(orig.dtypes, "__iter__") and any(
|
||||||
|
dtype == "bool" for dtype in orig.dtypes
|
||||||
|
):
|
||||||
assert not tm.predictor_equal(m_orig, m_etype)
|
assert not tm.predictor_equal(m_orig, m_etype)
|
||||||
else:
|
else:
|
||||||
assert tm.predictor_equal(m_orig, m_etype)
|
assert tm.predictor_equal(m_orig, m_etype)
|
||||||
|
|
||||||
|
if isinstance(df, pd.DataFrame):
|
||||||
f0 = df["f0"]
|
f0 = df["f0"]
|
||||||
with pytest.raises(ValueError, match="Label contains NaN"):
|
with pytest.raises(ValueError, match="Label contains NaN"):
|
||||||
xgb.DMatrix(df, f0, enable_categorical=enable_categorical)
|
xgb.DMatrix(df, f0, enable_categorical=enable_categorical)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user