Handle np integer in model slice and prediction. (#10007)
parent a76d6c6131
commit 65d7bf2dfe
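The snippet below is a minimal sketch of the behaviour this commit enables: NumPy integer scalars are now accepted for model slicing and for the bounds of iteration_range in prediction, in addition to built-in int. The toy data, parameters, and variable names are illustrative only and assume an XGBoost build that includes this change.

import numpy as np
import xgboost as xgb

# Illustrative toy data; any DMatrix would do here.
X = np.random.rand(64, 4)
y = np.random.rand(64)
dtrain = xgb.DMatrix(X, label=y)
booster = xgb.train({"tree_method": "hist"}, dtrain, num_boost_round=8)

# A NumPy integer now works for slicing out a single iteration ...
one_iteration = booster[np.int32(3)]
# ... and as a bound of iteration_range when predicting.
preds = booster.predict(dtrain, iteration_range=(0, np.int16(4)))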
@@ -36,6 +36,11 @@ PandasDType = Any  # real type is pandas.core.dtypes.base.ExtensionDtype
 
 FloatCompatible = Union[float, np.float32, np.float64]
 
+# typing.SupportsInt is not suitable here since floating point values are convertible to
+# integers as well.
+Integer = Union[int, np.integer]
+IterationRange = Tuple[Integer, Integer]
+
 # callables
 FPreProcCallable = Callable
 
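As a hypothetical illustration (not part of this diff) of how the new aliases are meant to read in annotations, a helper that normalises a possibly NumPy-typed range could look like the sketch below; clip_range is an invented name used only for this example.

from typing import Optional

import numpy as np

from xgboost._typing import Integer, IterationRange


def clip_range(r: Optional[IterationRange]) -> IterationRange:
    # Accept plain ints or NumPy integers and return plain ints.
    if r is None:
        return (0, 0)
    begin: Integer = r[0]
    end: Integer = r[1]
    return (int(begin), int(end))


assert clip_range((np.int32(0), np.int64(5))) == (0, 5)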
@@ -48,6 +48,8 @@ from ._typing import (
     FeatureInfo,
     FeatureNames,
     FeatureTypes,
+    Integer,
+    IterationRange,
     ModelIn,
     NumpyOrCupy,
     TransformedData,
@@ -1812,19 +1814,25 @@ class Booster:
         state["handle"] = handle
         self.__dict__.update(state)
 
-    def __getitem__(self, val: Union[int, tuple, slice]) -> "Booster":
+    def __getitem__(self, val: Union[Integer, tuple, slice]) -> "Booster":
         """Get a slice of the tree-based model.
 
         .. versionadded:: 1.3.0
 
         """
-        if isinstance(val, int):
-            val = slice(val, val + 1)
+        # convert to slice for all other types
+        if isinstance(val, (np.integer, int)):
+            val = slice(int(val), int(val + 1))
+        if isinstance(val, type(Ellipsis)):
+            val = slice(0, 0)
         if isinstance(val, tuple):
             raise ValueError("Only supports slicing through 1 dimension.")
+        # All supported types are now slice
+        # FIXME(jiamingy): Use `types.EllipsisType` once Python 3.10 is used.
         if not isinstance(val, slice):
-            msg = _expect((int, slice), type(val))
+            msg = _expect((int, slice, np.integer, type(Ellipsis)), type(val))
             raise TypeError(msg)
+
         if isinstance(val.start, type(Ellipsis)) or val.start is None:
             start = 0
         else:
@@ -2246,12 +2254,13 @@ class Booster:
         pred_interactions: bool = False,
         validate_features: bool = True,
         training: bool = False,
-        iteration_range: Tuple[int, int] = (0, 0),
+        iteration_range: IterationRange = (0, 0),
         strict_shape: bool = False,
     ) -> np.ndarray:
-        """Predict with data. The full model will be used unless `iteration_range` is specified,
-        meaning user have to either slice the model or use the ``best_iteration``
-        attribute to get prediction from best model returned from early stopping.
+        """Predict with data. The full model will be used unless `iteration_range` is
+        specified, meaning user have to either slice the model or use the
+        ``best_iteration`` attribute to get prediction from best model returned from
+        early stopping.
 
         .. note::
 
@@ -2336,8 +2345,8 @@ class Booster:
         args = {
             "type": 0,
             "training": training,
-            "iteration_begin": iteration_range[0],
-            "iteration_end": iteration_range[1],
+            "iteration_begin": int(iteration_range[0]),
+            "iteration_end": int(iteration_range[1]),
             "strict_shape": strict_shape,
         }
 
@@ -2373,7 +2382,7 @@ class Booster:
     def inplace_predict(
         self,
         data: DataType,
-        iteration_range: Tuple[int, int] = (0, 0),
+        iteration_range: IterationRange = (0, 0),
         predict_type: str = "value",
         missing: float = np.nan,
         validate_features: bool = True,
@@ -2439,8 +2448,8 @@ class Booster:
         args = make_jcargs(
             type=1 if predict_type == "margin" else 0,
             training=False,
-            iteration_begin=iteration_range[0],
-            iteration_end=iteration_range[1],
+            iteration_begin=int(iteration_range[0]),
+            iteration_end=int(iteration_range[1]),
             missing=missing,
             strict_shape=strict_shape,
             cache_id=0,
@@ -61,7 +61,7 @@ from typing import (
 import numpy
 
 from xgboost import collective, config
-from xgboost._typing import _T, FeatureNames, FeatureTypes
+from xgboost._typing import _T, FeatureNames, FeatureTypes, IterationRange
 from xgboost.callback import TrainingCallback
 from xgboost.compat import DataFrame, LazyLoader, concat, lazy_isinstance
 from xgboost.core import (
@@ -1263,7 +1263,7 @@ async def _predict_async(
     approx_contribs: bool,
     pred_interactions: bool,
     validate_features: bool,
-    iteration_range: Tuple[int, int],
+    iteration_range: IterationRange,
     strict_shape: bool,
 ) -> _DaskCollection:
     _booster = await _get_model_future(client, model)
@@ -1410,7 +1410,7 @@ def predict(  # pylint: disable=unused-argument
     approx_contribs: bool = False,
     pred_interactions: bool = False,
     validate_features: bool = True,
-    iteration_range: Tuple[int, int] = (0, 0),
+    iteration_range: IterationRange = (0, 0),
     strict_shape: bool = False,
 ) -> Any:
     """Run prediction with a trained booster.
@@ -1458,7 +1458,7 @@ async def _inplace_predict_async(  # pylint: disable=too-many-branches
     global_config: Dict[str, Any],
     model: Union[Booster, Dict, "distributed.Future"],
     data: _DataT,
-    iteration_range: Tuple[int, int],
+    iteration_range: IterationRange,
     predict_type: str,
     missing: float,
     validate_features: bool,
@@ -1516,7 +1516,7 @@ def inplace_predict(  # pylint: disable=unused-argument
     client: Optional["distributed.Client"],
     model: Union[TrainReturnT, Booster, "distributed.Future"],
     data: _DataT,
-    iteration_range: Tuple[int, int] = (0, 0),
+    iteration_range: IterationRange = (0, 0),
     predict_type: str = "value",
     missing: float = numpy.nan,
     validate_features: bool = True,
@@ -1624,7 +1624,7 @@ class DaskScikitLearnBase(XGBModel):
         output_margin: bool,
         validate_features: bool,
         base_margin: Optional[_DaskCollection],
-        iteration_range: Optional[Tuple[int, int]],
+        iteration_range: Optional[IterationRange],
     ) -> Any:
         iteration_range = self._get_iteration_range(iteration_range)
         if self._can_use_inplace_predict():
@@ -1664,7 +1664,7 @@ class DaskScikitLearnBase(XGBModel):
         output_margin: bool = False,
         validate_features: bool = True,
         base_margin: Optional[_DaskCollection] = None,
-        iteration_range: Optional[Tuple[int, int]] = None,
+        iteration_range: Optional[IterationRange] = None,
     ) -> Any:
         _assert_dask_support()
         return self.client.sync(
@@ -1679,7 +1679,7 @@ class DaskScikitLearnBase(XGBModel):
     async def _apply_async(
         self,
         X: _DataT,
-        iteration_range: Optional[Tuple[int, int]] = None,
+        iteration_range: Optional[IterationRange] = None,
     ) -> Any:
         iteration_range = self._get_iteration_range(iteration_range)
         test_dmatrix = await DaskDMatrix(
@@ -1700,7 +1700,7 @@ class DaskScikitLearnBase(XGBModel):
     def apply(
         self,
         X: _DataT,
-        iteration_range: Optional[Tuple[int, int]] = None,
+        iteration_range: Optional[IterationRange] = None,
     ) -> Any:
         _assert_dask_support()
         return self.client.sync(self._apply_async, X, iteration_range=iteration_range)
@@ -1962,7 +1962,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
         X: _DataT,
         validate_features: bool,
         base_margin: Optional[_DaskCollection],
-        iteration_range: Optional[Tuple[int, int]],
+        iteration_range: Optional[IterationRange],
     ) -> _DaskCollection:
         if self.objective == "multi:softmax":
             raise ValueError(
@@ -1987,7 +1987,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
         X: _DaskCollection,
         validate_features: bool = True,
         base_margin: Optional[_DaskCollection] = None,
-        iteration_range: Optional[Tuple[int, int]] = None,
+        iteration_range: Optional[IterationRange] = None,
     ) -> Any:
         _assert_dask_support()
         return self._client_sync(
@@ -2006,7 +2006,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
         output_margin: bool,
         validate_features: bool,
         base_margin: Optional[_DaskCollection],
-        iteration_range: Optional[Tuple[int, int]],
+        iteration_range: Optional[IterationRange],
     ) -> _DaskCollection:
         pred_probs = await super()._predict_async(
             data, output_margin, validate_features, base_margin, iteration_range
@@ -22,7 +22,7 @@ from typing import (
 import numpy as np
 from scipy.special import softmax
 
-from ._typing import ArrayLike, FeatureNames, FeatureTypes, ModelIn
+from ._typing import ArrayLike, FeatureNames, FeatureTypes, IterationRange, ModelIn
 from .callback import TrainingCallback
 
 # Do not use class names on scikit-learn directly. Re-define the classes on
@@ -1039,8 +1039,8 @@ class XGBModel(XGBModelBase):
         return False
 
     def _get_iteration_range(
-        self, iteration_range: Optional[Tuple[int, int]]
-    ) -> Tuple[int, int]:
+        self, iteration_range: Optional[IterationRange]
+    ) -> IterationRange:
         if iteration_range is None or iteration_range[1] == 0:
             # Use best_iteration if defined.
             try:
@@ -1057,7 +1057,7 @@ class XGBModel(XGBModelBase):
         output_margin: bool = False,
         validate_features: bool = True,
         base_margin: Optional[ArrayLike] = None,
-        iteration_range: Optional[Tuple[int, int]] = None,
+        iteration_range: Optional[IterationRange] = None,
     ) -> ArrayLike:
         """Predict with `X`. If the model is trained with early stopping, then
         :py:attr:`best_iteration` is used automatically. The estimator uses
@@ -1129,7 +1129,7 @@ class XGBModel(XGBModelBase):
     def apply(
         self,
         X: ArrayLike,
-        iteration_range: Optional[Tuple[int, int]] = None,
+        iteration_range: Optional[IterationRange] = None,
     ) -> np.ndarray:
         """Return the predicted leaf every tree for each sample. If the model is trained
         with early stopping, then :py:attr:`best_iteration` is used automatically.
@@ -1465,7 +1465,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         output_margin: bool = False,
         validate_features: bool = True,
         base_margin: Optional[ArrayLike] = None,
-        iteration_range: Optional[Tuple[int, int]] = None,
+        iteration_range: Optional[IterationRange] = None,
     ) -> ArrayLike:
         with config_context(verbosity=self.verbosity):
             class_probs = super().predict(
@@ -1500,7 +1500,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         X: ArrayLike,
         validate_features: bool = True,
         base_margin: Optional[ArrayLike] = None,
-        iteration_range: Optional[Tuple[int, int]] = None,
+        iteration_range: Optional[IterationRange] = None,
    ) -> np.ndarray:
         """Predict the probability of each `X` example being of a given class. If the
         model is trained with early stopping, then :py:attr:`best_iteration` is used
@@ -1942,7 +1942,7 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
         output_margin: bool = False,
         validate_features: bool = True,
         base_margin: Optional[ArrayLike] = None,
-        iteration_range: Optional[Tuple[int, int]] = None,
+        iteration_range: Optional[IterationRange] = None,
     ) -> ArrayLike:
         X, _ = _get_qid(X, None)
         return super().predict(
@@ -1956,7 +1956,7 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
     def apply(
         self,
         X: ArrayLike,
-        iteration_range: Optional[Tuple[int, int]] = None,
+        iteration_range: Optional[IterationRange] = None,
     ) -> ArrayLike:
         X, _ = _get_qid(X, None)
         return super().apply(X, iteration_range)
@@ -7,6 +7,7 @@ import pytest
 
 import xgboost as xgb
 from xgboost import testing as tm
+from xgboost.core import Integer
 from xgboost.testing.updater import ResetStrategy
 
 dpath = tm.data_dir(__file__)
@@ -97,15 +98,15 @@ class TestModels:
     def test_boost_from_prediction(self):
         # Re-construct dtrain here to avoid modification
         margined, _ = tm.load_agaricus(__file__)
-        bst = xgb.train({'tree_method': 'hist'}, margined, 1)
+        bst = xgb.train({"tree_method": "hist"}, margined, 1)
         predt_0 = bst.predict(margined, output_margin=True)
         margined.set_base_margin(predt_0)
-        bst = xgb.train({'tree_method': 'hist'}, margined, 1)
+        bst = xgb.train({"tree_method": "hist"}, margined, 1)
         predt_1 = bst.predict(margined)
 
         assert np.any(np.abs(predt_1 - predt_0) > 1e-6)
         dtrain, _ = tm.load_agaricus(__file__)
-        bst = xgb.train({'tree_method': 'hist'}, dtrain, 2)
+        bst = xgb.train({"tree_method": "hist"}, dtrain, 2)
         predt_2 = bst.predict(dtrain)
         assert np.all(np.abs(predt_2 - predt_1) < 1e-6)
 
@@ -331,10 +332,15 @@ class TestModels:
         dtrain: xgb.DMatrix,
         num_parallel_tree: int,
         num_classes: int,
-        num_boost_round: int
+        num_boost_round: int,
+        use_np_type: bool,
     ):
         beg = 3
-        end = 7
+        if use_np_type:
+            end: Integer = np.int32(7)
+        else:
+            end = 7
+
         sliced: xgb.Booster = booster[beg:end]
         assert sliced.feature_types == booster.feature_types
 
@@ -345,7 +351,7 @@ class TestModels:
         sliced = booster[beg:end:2]
         assert sliced_trees == len(sliced.get_dump())
 
-        sliced = booster[beg: ...]
+        sliced = booster[beg:]
         sliced_trees = (num_boost_round - beg) * num_parallel_tree * num_classes
         assert sliced_trees == len(sliced.get_dump())
 
@@ -357,7 +363,7 @@ class TestModels:
         sliced_trees = end * num_parallel_tree * num_classes
         assert sliced_trees == len(sliced.get_dump())
 
-        sliced = booster[...: end]
+        sliced = booster[: end]
         sliced_trees = end * num_parallel_tree * num_classes
         assert sliced_trees == len(sliced.get_dump())
 
@@ -383,14 +389,14 @@ class TestModels:
         assert len(trees) == num_boost_round
 
         with pytest.raises(TypeError):
-            booster["wrong type"]
+            booster["wrong type"]  # type: ignore
         with pytest.raises(IndexError):
             booster[: num_boost_round + 1]
         with pytest.raises(ValueError):
             booster[1, 2]  # too many dims
         # setitem is not implemented as model is immutable during slicing.
         with pytest.raises(TypeError):
-            booster[...: end] = booster
+            booster[:end] = booster  # type: ignore
 
         sliced_0 = booster[1:3]
         np.testing.assert_allclose(
@@ -446,15 +452,21 @@ class TestModels:
 
         assert len(booster.get_dump()) == total_trees
 
-        self.run_slice(booster, dtrain, num_parallel_tree, num_classes, num_boost_round)
+        self.run_slice(
+            booster, dtrain, num_parallel_tree, num_classes, num_boost_round, False
+        )
 
         bytesarray = booster.save_raw(raw_format="ubj")
         booster = xgb.Booster(model_file=bytesarray)
-        self.run_slice(booster, dtrain, num_parallel_tree, num_classes, num_boost_round)
+        self.run_slice(
+            booster, dtrain, num_parallel_tree, num_classes, num_boost_round, False
+        )
 
         bytesarray = booster.save_raw(raw_format="deprecated")
         booster = xgb.Booster(model_file=bytesarray)
-        self.run_slice(booster, dtrain, num_parallel_tree, num_classes, num_boost_round)
+        self.run_slice(
+            booster, dtrain, num_parallel_tree, num_classes, num_boost_round, True
+        )
 
     def test_slice_multi(self) -> None:
         from sklearn.datasets import make_classification
@@ -479,7 +491,7 @@ class TestModels:
             },
             num_boost_round=num_boost_round,
             dtrain=Xy,
-            callbacks=[ResetStrategy()]
+            callbacks=[ResetStrategy()],
         )
         sliced = [t for t in booster]
         assert len(sliced) == 16
@@ -61,7 +61,7 @@ def run_predict_leaf(device: str) -> np.ndarray:
 
     validate_leaf_output(leaf, num_parallel_tree)
 
-    n_iters = 2
+    n_iters = np.int32(2)
     sliced = booster.predict(
         m,
         pred_leaf=True,
@@ -440,7 +440,7 @@ def test_regression():
         preds = xgb_model.predict(X[test_index])
         # test other params in XGBRegressor().fit
         preds2 = xgb_model.predict(
-            X[test_index], output_margin=True, iteration_range=(0, 3)
+            X[test_index], output_margin=True, iteration_range=(0, np.int16(3))
         )
         preds3 = xgb_model.predict(
             X[test_index], output_margin=True, iteration_range=None