Handle np integer in model slice and prediction. (#10007)

Author: Jiaming Yuan
Date:   2024-01-26 04:58:48 +08:00 (committed by GitHub)
Parent: a76d6c6131
Commit: 65d7bf2dfe
7 changed files with 75 additions and 49 deletions
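
For context, the user-visible effect of this commit: NumPy integer scalars (the kind returned by reductions such as np.argmax) can now index a Booster and appear in iteration_range without an explicit int() cast. A minimal sketch of the new behavior; the dataset and parameters below are illustrative only:

    import numpy as np
    import xgboost as xgb
    from sklearn.datasets import make_regression

    X, y = make_regression(n_samples=256, n_features=8, random_state=0)
    dtrain = xgb.DMatrix(X, label=y)
    booster = xgb.train({"tree_method": "hist"}, dtrain, num_boost_round=8)

    # NumPy integers commonly leak out of argmin/argmax and similar reductions;
    # previously a scalar like this raised TypeError in Booster.__getitem__.
    best = np.int64(3)
    single_tree = booster[best]        # one-iteration slice
    head = booster[: np.int32(4)]      # NumPy integer as a slice bound

    # iteration_range tuples may now mix built-in and NumPy integers.
    predt = booster.predict(dtrain, iteration_range=(0, np.int16(3)))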

==== File 1 of 7 ====

@@ -36,6 +36,11 @@ PandasDType = Any  # real type is pandas.core.dtypes.base.ExtensionDtype
 FloatCompatible = Union[float, np.float32, np.float64]
+# typing.SupportsInt is not suitable here since floating point values are convertible to
+# integers as well.
+Integer = Union[int, np.integer]
+IterationRange = Tuple[Integer, Integer]
+
 # callables
 FPreProcCallable = Callable

==== File 2 of 7 ====

@@ -48,6 +48,8 @@ from ._typing import (
     FeatureInfo,
     FeatureNames,
     FeatureTypes,
+    Integer,
+    IterationRange,
     ModelIn,
     NumpyOrCupy,
     TransformedData,
@@ -1812,19 +1814,25 @@ class Booster:
         state["handle"] = handle
         self.__dict__.update(state)
 
-    def __getitem__(self, val: Union[int, tuple, slice]) -> "Booster":
+    def __getitem__(self, val: Union[Integer, tuple, slice]) -> "Booster":
         """Get a slice of the tree-based model.
 
         .. versionadded:: 1.3.0
 
         """
-        if isinstance(val, int):
-            val = slice(val, val + 1)
+        # convert to slice for all other types
+        if isinstance(val, (np.integer, int)):
+            val = slice(int(val), int(val + 1))
+        if isinstance(val, type(Ellipsis)):
+            val = slice(0, 0)
         if isinstance(val, tuple):
             raise ValueError("Only supports slicing through 1 dimension.")
+        # All supported types are now slice
+        # FIXME(jiamingy): Use `types.EllipsisType` once Python 3.10 is used.
         if not isinstance(val, slice):
-            msg = _expect((int, slice), type(val))
+            msg = _expect((int, slice, np.integer, type(Ellipsis)), type(val))
             raise TypeError(msg)
 
         if isinstance(val.start, type(Ellipsis)) or val.start is None:
             start = 0
         else:
@@ -2246,12 +2254,13 @@
         pred_interactions: bool = False,
         validate_features: bool = True,
         training: bool = False,
-        iteration_range: Tuple[int, int] = (0, 0),
+        iteration_range: IterationRange = (0, 0),
         strict_shape: bool = False,
     ) -> np.ndarray:
-        """Predict with data. The full model will be used unless `iteration_range` is specified,
-        meaning user have to either slice the model or use the ``best_iteration``
-        attribute to get prediction from best model returned from early stopping.
+        """Predict with data. The full model will be used unless `iteration_range` is
+        specified, meaning user have to either slice the model or use the
+        ``best_iteration`` attribute to get prediction from best model returned from
+        early stopping.
 
         .. note::
@@ -2336,8 +2345,8 @@
         args = {
             "type": 0,
             "training": training,
-            "iteration_begin": iteration_range[0],
-            "iteration_end": iteration_range[1],
+            "iteration_begin": int(iteration_range[0]),
+            "iteration_end": int(iteration_range[1]),
             "strict_shape": strict_shape,
         }
@@ -2373,7 +2382,7 @@
     def inplace_predict(
         self,
         data: DataType,
-        iteration_range: Tuple[int, int] = (0, 0),
+        iteration_range: IterationRange = (0, 0),
         predict_type: str = "value",
         missing: float = np.nan,
         validate_features: bool = True,
@@ -2439,8 +2448,8 @@
         args = make_jcargs(
             type=1 if predict_type == "margin" else 0,
             training=False,
-            iteration_begin=iteration_range[0],
-            iteration_end=iteration_range[1],
+            iteration_begin=int(iteration_range[0]),
+            iteration_end=int(iteration_range[1]),
             missing=missing,
             strict_shape=strict_shape,
             cache_id=0,
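
The __getitem__ change above normalizes every supported index form to a plain slice before any lower-level call sees it. A self-contained sketch of that pattern, with a hypothetical normalize_index helper standing in for the inline logic:

    import numpy as np
    from typing import Union

    # Mirrors the new alias in _typing.py.
    Integer = Union[int, np.integer]

    def normalize_index(val: object) -> slice:
        # An integer (built-in or NumPy scalar) becomes a one-tree slice; int()
        # strips the NumPy type so downstream calls receive plain Python ints.
        if isinstance(val, (np.integer, int)):
            return slice(int(val), int(val) + 1)
        # Ellipsis selects everything; slice(0, 0) is the "all trees" sentinel.
        if isinstance(val, type(Ellipsis)):
            return slice(0, 0)
        if not isinstance(val, slice):
            raise TypeError(f"Unsupported index type: {type(val)}")
        return val

    assert normalize_index(np.int32(3)) == slice(3, 4)
    assert normalize_index(...) == slice(0, 0)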

==== File 3 of 7 ====

@@ -61,7 +61,7 @@ from typing import (
 import numpy
 
 from xgboost import collective, config
-from xgboost._typing import _T, FeatureNames, FeatureTypes
+from xgboost._typing import _T, FeatureNames, FeatureTypes, IterationRange
 from xgboost.callback import TrainingCallback
 from xgboost.compat import DataFrame, LazyLoader, concat, lazy_isinstance
 from xgboost.core import (
@@ -1263,7 +1263,7 @@ async def _predict_async(
     approx_contribs: bool,
     pred_interactions: bool,
     validate_features: bool,
-    iteration_range: Tuple[int, int],
+    iteration_range: IterationRange,
     strict_shape: bool,
 ) -> _DaskCollection:
     _booster = await _get_model_future(client, model)
@@ -1410,7 +1410,7 @@ def predict(  # pylint: disable=unused-argument
     approx_contribs: bool = False,
     pred_interactions: bool = False,
     validate_features: bool = True,
-    iteration_range: Tuple[int, int] = (0, 0),
+    iteration_range: IterationRange = (0, 0),
     strict_shape: bool = False,
 ) -> Any:
     """Run prediction with a trained booster.
@@ -1458,7 +1458,7 @@ async def _inplace_predict_async(  # pylint: disable=too-many-branches
     global_config: Dict[str, Any],
     model: Union[Booster, Dict, "distributed.Future"],
     data: _DataT,
-    iteration_range: Tuple[int, int],
+    iteration_range: IterationRange,
     predict_type: str,
     missing: float,
     validate_features: bool,
@@ -1516,7 +1516,7 @@ def inplace_predict(  # pylint: disable=unused-argument
     client: Optional["distributed.Client"],
     model: Union[TrainReturnT, Booster, "distributed.Future"],
    data: _DataT,
-    iteration_range: Tuple[int, int] = (0, 0),
+    iteration_range: IterationRange = (0, 0),
     predict_type: str = "value",
     missing: float = numpy.nan,
     validate_features: bool = True,
@@ -1624,7 +1624,7 @@ class DaskScikitLearnBase(XGBModel):
         output_margin: bool,
         validate_features: bool,
         base_margin: Optional[_DaskCollection],
-        iteration_range: Optional[Tuple[int, int]],
+        iteration_range: Optional[IterationRange],
     ) -> Any:
         iteration_range = self._get_iteration_range(iteration_range)
         if self._can_use_inplace_predict():
@@ -1664,7 +1664,7 @@ class DaskScikitLearnBase(XGBModel):
         output_margin: bool = False,
         validate_features: bool = True,
         base_margin: Optional[_DaskCollection] = None,
-        iteration_range: Optional[Tuple[int, int]] = None,
+        iteration_range: Optional[IterationRange] = None,
     ) -> Any:
         _assert_dask_support()
         return self.client.sync(
@@ -1679,7 +1679,7 @@ class DaskScikitLearnBase(XGBModel):
     async def _apply_async(
         self,
         X: _DataT,
-        iteration_range: Optional[Tuple[int, int]] = None,
+        iteration_range: Optional[IterationRange] = None,
     ) -> Any:
         iteration_range = self._get_iteration_range(iteration_range)
         test_dmatrix = await DaskDMatrix(
@@ -1700,7 +1700,7 @@ class DaskScikitLearnBase(XGBModel):
     def apply(
         self,
         X: _DataT,
-        iteration_range: Optional[Tuple[int, int]] = None,
+        iteration_range: Optional[IterationRange] = None,
     ) -> Any:
         _assert_dask_support()
         return self.client.sync(self._apply_async, X, iteration_range=iteration_range)
@@ -1962,7 +1962,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
         X: _DataT,
         validate_features: bool,
         base_margin: Optional[_DaskCollection],
-        iteration_range: Optional[Tuple[int, int]],
+        iteration_range: Optional[IterationRange],
     ) -> _DaskCollection:
         if self.objective == "multi:softmax":
             raise ValueError(
@@ -1987,7 +1987,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
         X: _DaskCollection,
         validate_features: bool = True,
         base_margin: Optional[_DaskCollection] = None,
-        iteration_range: Optional[Tuple[int, int]] = None,
+        iteration_range: Optional[IterationRange] = None,
     ) -> Any:
         _assert_dask_support()
         return self._client_sync(
@@ -2006,7 +2006,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
         output_margin: bool,
         validate_features: bool,
         base_margin: Optional[_DaskCollection],
-        iteration_range: Optional[Tuple[int, int]],
+        iteration_range: Optional[IterationRange],
     ) -> _DaskCollection:
         pred_probs = await super()._predict_async(
             data, output_margin, validate_features, base_margin, iteration_range
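
The Dask changes are purely type-alias substitutions (Tuple[int, int] becomes IterationRange), so NumPy integers flow through the distributed predict paths the same way as in core. A hedged usage sketch; the cluster setup, data, and parameters are illustrative only:

    import numpy as np
    from dask import array as da
    from distributed import Client, LocalCluster
    import xgboost as xgb

    if __name__ == "__main__":
        with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
            X = da.random.random((1000, 10), chunks=(100, 10))
            y = da.random.random(1000, chunks=100)
            dtrain = xgb.dask.DaskDMatrix(client, X, y)
            out = xgb.dask.train(
                client, {"tree_method": "hist"}, dtrain, num_boost_round=8
            )
            # iteration_range accepts NumPy integers after this change.
            predt = xgb.dask.predict(
                client, out, dtrain, iteration_range=(0, np.int32(4))
            )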

==== File 4 of 7 ====

@@ -22,7 +22,7 @@ from typing import (
 import numpy as np
 from scipy.special import softmax
 
-from ._typing import ArrayLike, FeatureNames, FeatureTypes, ModelIn
+from ._typing import ArrayLike, FeatureNames, FeatureTypes, IterationRange, ModelIn
 from .callback import TrainingCallback
 
 # Do not use class names on scikit-learn directly. Re-define the classes on
@@ -1039,8 +1039,8 @@ class XGBModel(XGBModelBase):
         return False
 
     def _get_iteration_range(
-        self, iteration_range: Optional[Tuple[int, int]]
-    ) -> Tuple[int, int]:
+        self, iteration_range: Optional[IterationRange]
+    ) -> IterationRange:
         if iteration_range is None or iteration_range[1] == 0:
             # Use best_iteration if defined.
             try:
@@ -1057,7 +1057,7 @@
         output_margin: bool = False,
         validate_features: bool = True,
         base_margin: Optional[ArrayLike] = None,
-        iteration_range: Optional[Tuple[int, int]] = None,
+        iteration_range: Optional[IterationRange] = None,
     ) -> ArrayLike:
         """Predict with `X`. If the model is trained with early stopping, then
         :py:attr:`best_iteration` is used automatically. The estimator uses
@@ -1129,7 +1129,7 @@
     def apply(
         self,
         X: ArrayLike,
-        iteration_range: Optional[Tuple[int, int]] = None,
+        iteration_range: Optional[IterationRange] = None,
     ) -> np.ndarray:
         """Return the predicted leaf every tree for each sample. If the model is trained
         with early stopping, then :py:attr:`best_iteration` is used automatically.
@@ -1465,7 +1465,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         output_margin: bool = False,
         validate_features: bool = True,
         base_margin: Optional[ArrayLike] = None,
-        iteration_range: Optional[Tuple[int, int]] = None,
+        iteration_range: Optional[IterationRange] = None,
     ) -> ArrayLike:
         with config_context(verbosity=self.verbosity):
             class_probs = super().predict(
@@ -1500,7 +1500,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         X: ArrayLike,
         validate_features: bool = True,
         base_margin: Optional[ArrayLike] = None,
-        iteration_range: Optional[Tuple[int, int]] = None,
+        iteration_range: Optional[IterationRange] = None,
     ) -> np.ndarray:
         """Predict the probability of each `X` example being of a given class. If the
         model is trained with early stopping, then :py:attr:`best_iteration` is used
@@ -1942,7 +1942,7 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
         output_margin: bool = False,
         validate_features: bool = True,
         base_margin: Optional[ArrayLike] = None,
-        iteration_range: Optional[Tuple[int, int]] = None,
+        iteration_range: Optional[IterationRange] = None,
     ) -> ArrayLike:
         X, _ = _get_qid(X, None)
         return super().predict(
@@ -1956,7 +1956,7 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
     def apply(
         self,
         X: ArrayLike,
-        iteration_range: Optional[Tuple[int, int]] = None,
+        iteration_range: Optional[IterationRange] = None,
     ) -> ArrayLike:
         X, _ = _get_qid(X, None)
         return super().apply(X, iteration_range)
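
For intuition on the _get_iteration_range hunk above: at the estimator level, a missing or zero-width range falls back to the best iteration recorded by early stopping. A hedged sketch that mirrors this resolution logic; resolve_iteration_range is a hypothetical stand-in, not the library's method:

    import numpy as np
    from typing import Optional, Tuple, Union

    Integer = Union[int, np.integer]
    IterationRange = Tuple[Integer, Integer]

    def resolve_iteration_range(
        iteration_range: Optional[IterationRange], best_iteration: Optional[int]
    ) -> IterationRange:
        # (0, 0) is the "use all trees" sentinel; None or a zero end falls back
        # to best_iteration when early stopping recorded one.
        if iteration_range is None or iteration_range[1] == 0:
            if best_iteration is not None:
                return (0, best_iteration + 1)
            return (0, 0)
        return iteration_range

    assert resolve_iteration_range(None, best_iteration=5) == (0, 6)
    assert resolve_iteration_range((0, np.int16(3)), best_iteration=5) == (0, 3)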

==== File 5 of 7 ====

@@ -7,6 +7,7 @@ import pytest
 import xgboost as xgb
 from xgboost import testing as tm
+from xgboost.core import Integer
 from xgboost.testing.updater import ResetStrategy
 
 dpath = tm.data_dir(__file__)
@@ -97,15 +98,15 @@ class TestModels:
     def test_boost_from_prediction(self):
         # Re-construct dtrain here to avoid modification
         margined, _ = tm.load_agaricus(__file__)
-        bst = xgb.train({'tree_method': 'hist'}, margined, 1)
+        bst = xgb.train({"tree_method": "hist"}, margined, 1)
         predt_0 = bst.predict(margined, output_margin=True)
         margined.set_base_margin(predt_0)
-        bst = xgb.train({'tree_method': 'hist'}, margined, 1)
+        bst = xgb.train({"tree_method": "hist"}, margined, 1)
         predt_1 = bst.predict(margined)
         assert np.any(np.abs(predt_1 - predt_0) > 1e-6)
 
         dtrain, _ = tm.load_agaricus(__file__)
-        bst = xgb.train({'tree_method': 'hist'}, dtrain, 2)
+        bst = xgb.train({"tree_method": "hist"}, dtrain, 2)
         predt_2 = bst.predict(dtrain)
         assert np.all(np.abs(predt_2 - predt_1) < 1e-6)
@@ -331,10 +332,15 @@
         dtrain: xgb.DMatrix,
         num_parallel_tree: int,
         num_classes: int,
-        num_boost_round: int
+        num_boost_round: int,
+        use_np_type: bool,
     ):
         beg = 3
-        end = 7
+        if use_np_type:
+            end: Integer = np.int32(7)
+        else:
+            end = 7
 
         sliced: xgb.Booster = booster[beg:end]
         assert sliced.feature_types == booster.feature_types
@@ -345,7 +351,7 @@
         sliced = booster[beg:end:2]
         assert sliced_trees == len(sliced.get_dump())
 
-        sliced = booster[beg: ...]
+        sliced = booster[beg:]
         sliced_trees = (num_boost_round - beg) * num_parallel_tree * num_classes
         assert sliced_trees == len(sliced.get_dump())
@@ -357,7 +363,7 @@
         sliced_trees = end * num_parallel_tree * num_classes
         assert sliced_trees == len(sliced.get_dump())
 
-        sliced = booster[...: end]
+        sliced = booster[:end]
         sliced_trees = end * num_parallel_tree * num_classes
         assert sliced_trees == len(sliced.get_dump())
@@ -383,14 +389,14 @@
         assert len(trees) == num_boost_round
 
         with pytest.raises(TypeError):
-            booster["wrong type"]
+            booster["wrong type"]  # type: ignore
         with pytest.raises(IndexError):
             booster[: num_boost_round + 1]
         with pytest.raises(ValueError):
             booster[1, 2]  # too many dims
         # setitem is not implemented as model is immutable during slicing.
         with pytest.raises(TypeError):
-            booster[...: end] = booster
+            booster[:end] = booster  # type: ignore
 
         sliced_0 = booster[1:3]
         np.testing.assert_allclose(
@@ -446,15 +452,21 @@
         assert len(booster.get_dump()) == total_trees
 
-        self.run_slice(booster, dtrain, num_parallel_tree, num_classes, num_boost_round)
+        self.run_slice(
+            booster, dtrain, num_parallel_tree, num_classes, num_boost_round, False
+        )
 
         bytesarray = booster.save_raw(raw_format="ubj")
         booster = xgb.Booster(model_file=bytesarray)
-        self.run_slice(booster, dtrain, num_parallel_tree, num_classes, num_boost_round)
+        self.run_slice(
+            booster, dtrain, num_parallel_tree, num_classes, num_boost_round, False
+        )
 
         bytesarray = booster.save_raw(raw_format="deprecated")
         booster = xgb.Booster(model_file=bytesarray)
-        self.run_slice(booster, dtrain, num_parallel_tree, num_classes, num_boost_round)
+        self.run_slice(
+            booster, dtrain, num_parallel_tree, num_classes, num_boost_round, True
+        )
 
     def test_slice_multi(self) -> None:
         from sklearn.datasets import make_classification
@@ -479,7 +491,7 @@
             },
             num_boost_round=num_boost_round,
             dtrain=Xy,
-            callbacks=[ResetStrategy()]
+            callbacks=[ResetStrategy()],
         )
         sliced = [t for t in booster]
         assert len(sliced) == 16

==== File 6 of 7 ====

@@ -61,7 +61,7 @@ def run_predict_leaf(device: str) -> np.ndarray:
     validate_leaf_output(leaf, num_parallel_tree)
 
-    n_iters = 2
+    n_iters = np.int32(2)
     sliced = booster.predict(
         m,
         pred_leaf=True,

==== File 7 of 7 ====

@@ -440,7 +440,7 @@ def test_regression():
         preds = xgb_model.predict(X[test_index])
         # test other params in XGBRegressor().fit
         preds2 = xgb_model.predict(
-            X[test_index], output_margin=True, iteration_range=(0, 3)
+            X[test_index], output_margin=True, iteration_range=(0, np.int16(3))
         )
         preds3 = xgb_model.predict(
             X[test_index], output_margin=True, iteration_range=None
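
To reproduce what the updated tests exercise at the estimator level, a hedged end-to-end check; the dataset and hyperparameters are illustrative only:

    import numpy as np
    import xgboost as xgb
    from sklearn.datasets import make_regression

    X, y = make_regression(n_samples=200, n_features=10, random_state=0)
    model = xgb.XGBRegressor(n_estimators=8, tree_method="hist").fit(X, y)

    # NumPy integer bounds should behave identically to built-in ints.
    a = model.predict(X, iteration_range=(0, 3))
    b = model.predict(X, iteration_range=(0, np.int16(3)))
    np.testing.assert_allclose(a, b)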