Handle NumPy integers in model slice and prediction. (#10007)

Jiaming Yuan 2024-01-26 04:58:48 +08:00 committed by GitHub
parent a76d6c6131
commit 65d7bf2dfe
7 changed files with 75 additions and 49 deletions


@@ -36,6 +36,11 @@ PandasDType = Any # real type is pandas.core.dtypes.base.ExtensionDtype
FloatCompatible = Union[float, np.float32, np.float64]
# typing.SupportsInt is not suitable here since floating point values are convertible to
# integers as well.
Integer = Union[int, np.integer]
IterationRange = Tuple[Integer, Integer]
# callables
FPreProcCallable = Callable
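
The new aliases spell out which integer types are accepted. A minimal sketch of how code can rely on them; the `n_rounds` helper below is hypothetical, not part of this change:

import numpy as np
from typing import Tuple, Union

# Mirrors the aliases added above.
Integer = Union[int, np.integer]
IterationRange = Tuple[Integer, Integer]

def n_rounds(iteration_range: IterationRange) -> int:
    # Hypothetical helper: count the boosting rounds covered by the range.
    begin, end = int(iteration_range[0]), int(iteration_range[1])
    return end - begin

assert n_rounds((np.int32(0), np.int32(4))) == 4  # NumPy integers now fit the alias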


@@ -48,6 +48,8 @@ from ._typing import (
FeatureInfo,
FeatureNames,
FeatureTypes,
Integer,
IterationRange,
ModelIn,
NumpyOrCupy,
TransformedData,
@@ -1812,19 +1814,25 @@ class Booster:
state["handle"] = handle
self.__dict__.update(state)
def __getitem__(self, val: Union[int, tuple, slice]) -> "Booster":
def __getitem__(self, val: Union[Integer, tuple, slice]) -> "Booster":
"""Get a slice of the tree-based model.
.. versionadded:: 1.3.0
"""
if isinstance(val, int):
val = slice(val, val + 1)
# convert to slice for all other types
if isinstance(val, (np.integer, int)):
val = slice(int(val), int(val + 1))
if isinstance(val, type(Ellipsis)):
val = slice(0, 0)
if isinstance(val, tuple):
raise ValueError("Only supports slicing through 1 dimension.")
# All supported types are now slice
# FIXME(jiamingy): Use `types.EllipsisType` once Python 3.10 is used.
if not isinstance(val, slice):
msg = _expect((int, slice), type(val))
msg = _expect((int, slice, np.integer, type(Ellipsis)), type(val))
raise TypeError(msg)
if isinstance(val.start, type(Ellipsis)) or val.start is None:
start = 0
else:
@@ -2246,12 +2254,13 @@ class Booster:
pred_interactions: bool = False,
validate_features: bool = True,
training: bool = False,
iteration_range: Tuple[int, int] = (0, 0),
iteration_range: IterationRange = (0, 0),
strict_shape: bool = False,
) -> np.ndarray:
"""Predict with data. The full model will be used unless `iteration_range` is specified,
meaning user have to either slice the model or use the ``best_iteration``
attribute to get prediction from best model returned from early stopping.
"""Predict with data. The full model will be used unless `iteration_range` is
specified, meaning user have to either slice the model or use the
``best_iteration`` attribute to get prediction from best model returned from
early stopping.
.. note::
@@ -2336,8 +2345,8 @@
args = {
"type": 0,
"training": training,
"iteration_begin": iteration_range[0],
"iteration_end": iteration_range[1],
"iteration_begin": int(iteration_range[0]),
"iteration_end": int(iteration_range[1]),
"strict_shape": strict_shape,
}
@@ -2373,7 +2382,7 @@
def inplace_predict(
self,
data: DataType,
iteration_range: Tuple[int, int] = (0, 0),
iteration_range: IterationRange = (0, 0),
predict_type: str = "value",
missing: float = np.nan,
validate_features: bool = True,
@@ -2439,8 +2448,8 @@
args = make_jcargs(
type=1 if predict_type == "margin" else 0,
training=False,
iteration_begin=iteration_range[0],
iteration_end=iteration_range[1],
iteration_begin=int(iteration_range[0]),
iteration_end=int(iteration_range[1]),
missing=missing,
strict_shape=strict_shape,
cache_id=0,
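
Taken together, the slicing and prediction paths now normalize NumPy integers with int() before handing them to the C API. A small usage sketch on synthetic data, assuming only the public Booster API shown above:

import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X = rng.random((128, 4))
y = X.sum(axis=1)
dtrain = xgb.DMatrix(X, label=y)
booster = xgb.train({"tree_method": "hist"}, dtrain, num_boost_round=8)

# NumPy integers are accepted when slicing trees...
sliced = booster[np.int32(0) : np.int32(4)]
# ...and as iteration_range bounds in both prediction entry points.
predt = booster.predict(dtrain, iteration_range=(np.int64(0), np.int64(4)))
predt_inplace = booster.inplace_predict(X, iteration_range=(np.int16(0), np.int16(4)))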


@@ -61,7 +61,7 @@ from typing import (
import numpy
from xgboost import collective, config
from xgboost._typing import _T, FeatureNames, FeatureTypes
from xgboost._typing import _T, FeatureNames, FeatureTypes, IterationRange
from xgboost.callback import TrainingCallback
from xgboost.compat import DataFrame, LazyLoader, concat, lazy_isinstance
from xgboost.core import (
@@ -1263,7 +1263,7 @@ async def _predict_async(
approx_contribs: bool,
pred_interactions: bool,
validate_features: bool,
iteration_range: Tuple[int, int],
iteration_range: IterationRange,
strict_shape: bool,
) -> _DaskCollection:
_booster = await _get_model_future(client, model)
@@ -1410,7 +1410,7 @@ def predict( # pylint: disable=unused-argument
approx_contribs: bool = False,
pred_interactions: bool = False,
validate_features: bool = True,
iteration_range: Tuple[int, int] = (0, 0),
iteration_range: IterationRange = (0, 0),
strict_shape: bool = False,
) -> Any:
"""Run prediction with a trained booster.
@@ -1458,7 +1458,7 @@ async def _inplace_predict_async( # pylint: disable=too-many-branches
global_config: Dict[str, Any],
model: Union[Booster, Dict, "distributed.Future"],
data: _DataT,
iteration_range: Tuple[int, int],
iteration_range: IterationRange,
predict_type: str,
missing: float,
validate_features: bool,
@@ -1516,7 +1516,7 @@ def inplace_predict( # pylint: disable=unused-argument
client: Optional["distributed.Client"],
model: Union[TrainReturnT, Booster, "distributed.Future"],
data: _DataT,
iteration_range: Tuple[int, int] = (0, 0),
iteration_range: IterationRange = (0, 0),
predict_type: str = "value",
missing: float = numpy.nan,
validate_features: bool = True,
@@ -1624,7 +1624,7 @@ class DaskScikitLearnBase(XGBModel):
output_margin: bool,
validate_features: bool,
base_margin: Optional[_DaskCollection],
iteration_range: Optional[Tuple[int, int]],
iteration_range: Optional[IterationRange],
) -> Any:
iteration_range = self._get_iteration_range(iteration_range)
if self._can_use_inplace_predict():
@@ -1664,7 +1664,7 @@ class DaskScikitLearnBase(XGBModel):
output_margin: bool = False,
validate_features: bool = True,
base_margin: Optional[_DaskCollection] = None,
iteration_range: Optional[Tuple[int, int]] = None,
iteration_range: Optional[IterationRange] = None,
) -> Any:
_assert_dask_support()
return self.client.sync(
@@ -1679,7 +1679,7 @@ class DaskScikitLearnBase(XGBModel):
async def _apply_async(
self,
X: _DataT,
iteration_range: Optional[Tuple[int, int]] = None,
iteration_range: Optional[IterationRange] = None,
) -> Any:
iteration_range = self._get_iteration_range(iteration_range)
test_dmatrix = await DaskDMatrix(
@@ -1700,7 +1700,7 @@ class DaskScikitLearnBase(XGBModel):
def apply(
self,
X: _DataT,
iteration_range: Optional[Tuple[int, int]] = None,
iteration_range: Optional[IterationRange] = None,
) -> Any:
_assert_dask_support()
return self.client.sync(self._apply_async, X, iteration_range=iteration_range)
@@ -1962,7 +1962,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
X: _DataT,
validate_features: bool,
base_margin: Optional[_DaskCollection],
iteration_range: Optional[Tuple[int, int]],
iteration_range: Optional[IterationRange],
) -> _DaskCollection:
if self.objective == "multi:softmax":
raise ValueError(
@@ -1987,7 +1987,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
X: _DaskCollection,
validate_features: bool = True,
base_margin: Optional[_DaskCollection] = None,
iteration_range: Optional[Tuple[int, int]] = None,
iteration_range: Optional[IterationRange] = None,
) -> Any:
_assert_dask_support()
return self._client_sync(
@@ -2006,7 +2006,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
output_margin: bool,
validate_features: bool,
base_margin: Optional[_DaskCollection],
iteration_range: Optional[Tuple[int, int]],
iteration_range: Optional[IterationRange],
) -> _DaskCollection:
pred_probs = await super()._predict_async(
data, output_margin, validate_features, base_margin, iteration_range
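
The Dask layer forwards iteration_range unchanged, so the widened alias applies there as well. A hedged sketch, assuming a local cluster and synthetic Dask arrays purely for illustration:

import numpy as np
import xgboost as xgb
from dask import array as da
from distributed import Client, LocalCluster

if __name__ == "__main__":
    with LocalCluster(n_workers=2, threads_per_worker=1) as cluster, Client(cluster) as client:
        X = da.random.random((1024, 8), chunks=(256, 8))
        y = da.random.random(1024, chunks=(256,))
        Xy = xgb.dask.DaskDMatrix(client, X, y)
        out = xgb.dask.train(client, {"tree_method": "hist"}, Xy, num_boost_round=8)
        # NumPy integers are now valid iteration_range bounds here too.
        predt = xgb.dask.predict(client, out, Xy, iteration_range=(np.int32(0), np.int32(4)))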


@@ -22,7 +22,7 @@ from typing import (
import numpy as np
from scipy.special import softmax
from ._typing import ArrayLike, FeatureNames, FeatureTypes, ModelIn
from ._typing import ArrayLike, FeatureNames, FeatureTypes, IterationRange, ModelIn
from .callback import TrainingCallback
# Do not use class names on scikit-learn directly. Re-define the classes on
@@ -1039,8 +1039,8 @@ class XGBModel(XGBModelBase):
return False
def _get_iteration_range(
self, iteration_range: Optional[Tuple[int, int]]
) -> Tuple[int, int]:
self, iteration_range: Optional[IterationRange]
) -> IterationRange:
if iteration_range is None or iteration_range[1] == 0:
# Use best_iteration if defined.
try:
@@ -1057,7 +1057,7 @@
output_margin: bool = False,
validate_features: bool = True,
base_margin: Optional[ArrayLike] = None,
iteration_range: Optional[Tuple[int, int]] = None,
iteration_range: Optional[IterationRange] = None,
) -> ArrayLike:
"""Predict with `X`. If the model is trained with early stopping, then
:py:attr:`best_iteration` is used automatically. The estimator uses
@@ -1129,7 +1129,7 @@
def apply(
self,
X: ArrayLike,
iteration_range: Optional[Tuple[int, int]] = None,
iteration_range: Optional[IterationRange] = None,
) -> np.ndarray:
"""Return the predicted leaf every tree for each sample. If the model is trained
with early stopping, then :py:attr:`best_iteration` is used automatically.
@@ -1465,7 +1465,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
output_margin: bool = False,
validate_features: bool = True,
base_margin: Optional[ArrayLike] = None,
iteration_range: Optional[Tuple[int, int]] = None,
iteration_range: Optional[IterationRange] = None,
) -> ArrayLike:
with config_context(verbosity=self.verbosity):
class_probs = super().predict(
@@ -1500,7 +1500,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
X: ArrayLike,
validate_features: bool = True,
base_margin: Optional[ArrayLike] = None,
iteration_range: Optional[Tuple[int, int]] = None,
iteration_range: Optional[IterationRange] = None,
) -> np.ndarray:
"""Predict the probability of each `X` example being of a given class. If the
model is trained with early stopping, then :py:attr:`best_iteration` is used
@@ -1942,7 +1942,7 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
output_margin: bool = False,
validate_features: bool = True,
base_margin: Optional[ArrayLike] = None,
iteration_range: Optional[Tuple[int, int]] = None,
iteration_range: Optional[IterationRange] = None,
) -> ArrayLike:
X, _ = _get_qid(X, None)
return super().predict(
@@ -1956,7 +1956,7 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
def apply(
self,
X: ArrayLike,
iteration_range: Optional[Tuple[int, int]] = None,
iteration_range: Optional[IterationRange] = None,
) -> ArrayLike:
X, _ = _get_qid(X, None)
return super().apply(X, iteration_range)
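
At the scikit-learn level the same range type flows through predict, predict_proba, and apply. A brief sketch, using a synthetic regression problem for illustration:

import numpy as np
import xgboost as xgb
from sklearn.datasets import make_regression

X, y = make_regression(n_samples=256, n_features=8, random_state=0)
reg = xgb.XGBRegressor(n_estimators=8, tree_method="hist").fit(X, y)

# NumPy integers are accepted wherever an iteration range is given.
partial = reg.predict(X, iteration_range=(np.int16(0), np.int16(4)))
leaves = reg.apply(X, iteration_range=(0, np.int32(4)))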


@@ -7,6 +7,7 @@ import pytest
import xgboost as xgb
from xgboost import testing as tm
from xgboost.core import Integer
from xgboost.testing.updater import ResetStrategy
dpath = tm.data_dir(__file__)
@@ -97,15 +98,15 @@ class TestModels:
def test_boost_from_prediction(self):
# Re-construct dtrain here to avoid modification
margined, _ = tm.load_agaricus(__file__)
bst = xgb.train({'tree_method': 'hist'}, margined, 1)
bst = xgb.train({"tree_method": "hist"}, margined, 1)
predt_0 = bst.predict(margined, output_margin=True)
margined.set_base_margin(predt_0)
bst = xgb.train({'tree_method': 'hist'}, margined, 1)
bst = xgb.train({"tree_method": "hist"}, margined, 1)
predt_1 = bst.predict(margined)
assert np.any(np.abs(predt_1 - predt_0) > 1e-6)
dtrain, _ = tm.load_agaricus(__file__)
bst = xgb.train({'tree_method': 'hist'}, dtrain, 2)
bst = xgb.train({"tree_method": "hist"}, dtrain, 2)
predt_2 = bst.predict(dtrain)
assert np.all(np.abs(predt_2 - predt_1) < 1e-6)
@@ -331,10 +332,15 @@
dtrain: xgb.DMatrix,
num_parallel_tree: int,
num_classes: int,
num_boost_round: int
num_boost_round: int,
use_np_type: bool,
):
beg = 3
end = 7
if use_np_type:
end: Integer = np.int32(7)
else:
end = 7
sliced: xgb.Booster = booster[beg:end]
assert sliced.feature_types == booster.feature_types
@@ -345,7 +351,7 @@
sliced = booster[beg:end:2]
assert sliced_trees == len(sliced.get_dump())
sliced = booster[beg: ...]
sliced = booster[beg:]
sliced_trees = (num_boost_round - beg) * num_parallel_tree * num_classes
assert sliced_trees == len(sliced.get_dump())
@@ -357,7 +363,7 @@
sliced_trees = end * num_parallel_tree * num_classes
assert sliced_trees == len(sliced.get_dump())
sliced = booster[...: end]
sliced = booster[: end]
sliced_trees = end * num_parallel_tree * num_classes
assert sliced_trees == len(sliced.get_dump())
@@ -383,14 +389,14 @@
assert len(trees) == num_boost_round
with pytest.raises(TypeError):
booster["wrong type"]
booster["wrong type"] # type: ignore
with pytest.raises(IndexError):
booster[: num_boost_round + 1]
with pytest.raises(ValueError):
booster[1, 2] # too many dims
# setitem is not implemented as model is immutable during slicing.
with pytest.raises(TypeError):
booster[...: end] = booster
booster[:end] = booster # type: ignore
sliced_0 = booster[1:3]
np.testing.assert_allclose(
@@ -446,15 +452,21 @@
assert len(booster.get_dump()) == total_trees
self.run_slice(booster, dtrain, num_parallel_tree, num_classes, num_boost_round)
self.run_slice(
booster, dtrain, num_parallel_tree, num_classes, num_boost_round, False
)
bytesarray = booster.save_raw(raw_format="ubj")
booster = xgb.Booster(model_file=bytesarray)
self.run_slice(booster, dtrain, num_parallel_tree, num_classes, num_boost_round)
self.run_slice(
booster, dtrain, num_parallel_tree, num_classes, num_boost_round, False
)
bytesarray = booster.save_raw(raw_format="deprecated")
booster = xgb.Booster(model_file=bytesarray)
self.run_slice(booster, dtrain, num_parallel_tree, num_classes, num_boost_round)
self.run_slice(
booster, dtrain, num_parallel_tree, num_classes, num_boost_round, True
)
def test_slice_multi(self) -> None:
from sklearn.datasets import make_classification
@@ -479,7 +491,7 @@
},
num_boost_round=num_boost_round,
dtrain=Xy,
callbacks=[ResetStrategy()]
callbacks=[ResetStrategy()],
)
sliced = [t for t in booster]
assert len(sliced) == 16


@@ -61,7 +61,7 @@ def run_predict_leaf(device: str) -> np.ndarray:
validate_leaf_output(leaf, num_parallel_tree)
n_iters = 2
n_iters = np.int32(2)
sliced = booster.predict(
m,
pred_leaf=True,


@@ -440,7 +440,7 @@ def test_regression():
preds = xgb_model.predict(X[test_index])
# test other params in XGBRegressor().fit
preds2 = xgb_model.predict(
X[test_index], output_margin=True, iteration_range=(0, 3)
X[test_index], output_margin=True, iteration_range=(0, np.int16(3))
)
preds3 = xgb_model.predict(
X[test_index], output_margin=True, iteration_range=None