Handle np integer in model slice and prediction. (#10007)
This commit is contained in:
parent a76d6c6131
commit 65d7bf2dfe
@@ -36,6 +36,11 @@ PandasDType = Any  # real type is pandas.core.dtypes.base.ExtensionDtype
 FloatCompatible = Union[float, np.float32, np.float64]
 
+# typing.SupportsInt is not suitable here since floating point values are convertible to
+# integers as well.
+Integer = Union[int, np.integer]
+IterationRange = Tuple[Integer, Integer]
+
 # callables
 FPreProcCallable = Callable
 
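For illustration only (not part of the commit), a minimal sketch of how the new aliases read: a plain Python int or any NumPy integer scalar satisfies Integer, and an IterationRange is simply a pair of such values. The helper first_n_trees is hypothetical.

import numpy as np
from typing import Tuple, Union

Integer = Union[int, np.integer]
IterationRange = Tuple[Integer, Integer]

def first_n_trees(n: Integer) -> IterationRange:
    # A plain int or a NumPy scalar such as np.int32(5) both satisfy Integer.
    return (0, int(n))

assert first_n_trees(np.int64(5)) == (0, 5)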
@@ -48,6 +48,8 @@ from ._typing import (
     FeatureInfo,
     FeatureNames,
     FeatureTypes,
+    Integer,
+    IterationRange,
     ModelIn,
     NumpyOrCupy,
     TransformedData,
@@ -1812,19 +1814,25 @@ class Booster:
         state["handle"] = handle
         self.__dict__.update(state)
 
-    def __getitem__(self, val: Union[int, tuple, slice]) -> "Booster":
+    def __getitem__(self, val: Union[Integer, tuple, slice]) -> "Booster":
         """Get a slice of the tree-based model.
 
         .. versionadded:: 1.3.0
 
         """
-        if isinstance(val, int):
-            val = slice(val, val + 1)
+        # convert to slice for all other types
+        if isinstance(val, (np.integer, int)):
+            val = slice(int(val), int(val + 1))
+        if isinstance(val, type(Ellipsis)):
+            val = slice(0, 0)
         if isinstance(val, tuple):
             raise ValueError("Only supports slicing through 1 dimension.")
+        # All supported types are now slice
+        # FIXME(jiamingy): Use `types.EllipsisType` once Python 3.10 is used.
         if not isinstance(val, slice):
-            msg = _expect((int, slice), type(val))
+            msg = _expect((int, slice, np.integer, type(Ellipsis)), type(val))
             raise TypeError(msg)
 
         if isinstance(val.start, type(Ellipsis)) or val.start is None:
             start = 0
         else:
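A hedged usage sketch (not part of the diff) of what the widened __getitem__ accepts: a NumPy integer index, NumPy integers inside a slice (as exercised by run_slice in the tests below), and a bare Ellipsis, which the new branch maps to slice(0, 0), i.e. the whole model.

import numpy as np
import xgboost as xgb

X, y = np.random.rand(64, 4), np.random.rand(64)
booster = xgb.train({"tree_method": "hist"}, xgb.DMatrix(X, y), num_boost_round=8)

one_tree = booster[np.int32(3)]   # NumPy integer index, previously a TypeError
middle = booster[2:np.int32(7)]   # NumPy integer as a slice bound
whole = booster[...]              # Ellipsis now maps to slice(0, 0)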
@@ -2246,12 +2254,13 @@ class Booster:
         pred_interactions: bool = False,
         validate_features: bool = True,
         training: bool = False,
-        iteration_range: Tuple[int, int] = (0, 0),
+        iteration_range: IterationRange = (0, 0),
         strict_shape: bool = False,
     ) -> np.ndarray:
-        """Predict with data. The full model will be used unless `iteration_range` is specified,
-        meaning user have to either slice the model or use the ``best_iteration``
-        attribute to get prediction from best model returned from early stopping.
+        """Predict with data. The full model will be used unless `iteration_range` is
+        specified, meaning user have to either slice the model or use the
+        ``best_iteration`` attribute to get prediction from best model returned from
+        early stopping.
 
         .. note::
 
@@ -2336,8 +2345,8 @@ class Booster:
         args = {
             "type": 0,
             "training": training,
-            "iteration_begin": iteration_range[0],
-            "iteration_end": iteration_range[1],
+            "iteration_begin": int(iteration_range[0]),
+            "iteration_end": int(iteration_range[1]),
             "strict_shape": strict_shape,
         }
 
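Illustrative call (a sketch, not part of the diff): the bounds are now passed through int() before the prediction config is built, presumably so NumPy scalar types no longer trip its serialization, which means NumPy integers can be used directly in iteration_range.

import numpy as np
import xgboost as xgb

X, y = np.random.rand(64, 4), np.random.rand(64)
dtrain = xgb.DMatrix(X, y)
booster = xgb.train({"tree_method": "hist"}, dtrain, num_boost_round=10)

# NumPy integer bounds are cast to plain int before reaching the C API.
first_four = booster.predict(dtrain, iteration_range=(np.int64(0), np.int32(4)))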
@@ -2373,7 +2382,7 @@ class Booster:
     def inplace_predict(
         self,
         data: DataType,
-        iteration_range: Tuple[int, int] = (0, 0),
+        iteration_range: IterationRange = (0, 0),
         predict_type: str = "value",
         missing: float = np.nan,
         validate_features: bool = True,
@@ -2439,8 +2448,8 @@ class Booster:
         args = make_jcargs(
             type=1 if predict_type == "margin" else 0,
             training=False,
-            iteration_begin=iteration_range[0],
-            iteration_end=iteration_range[1],
+            iteration_begin=int(iteration_range[0]),
+            iteration_end=int(iteration_range[1]),
             missing=missing,
             strict_shape=strict_shape,
             cache_id=0,
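The same int() casting applies to inplace_predict, which predicts on the raw array without building a DMatrix; a short sketch (not part of the diff):

import numpy as np
import xgboost as xgb

X, y = np.random.rand(64, 4), np.random.rand(64)
booster = xgb.train({"tree_method": "hist"}, xgb.DMatrix(X, y), num_boost_round=10)

# Predict directly on the array; NumPy integers are accepted in iteration_range.
preds = booster.inplace_predict(X, iteration_range=(np.int32(0), np.int32(4)))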
@@ -61,7 +61,7 @@ from typing import (
 import numpy
 
 from xgboost import collective, config
-from xgboost._typing import _T, FeatureNames, FeatureTypes
+from xgboost._typing import _T, FeatureNames, FeatureTypes, IterationRange
 from xgboost.callback import TrainingCallback
 from xgboost.compat import DataFrame, LazyLoader, concat, lazy_isinstance
 from xgboost.core import (
@@ -1263,7 +1263,7 @@ async def _predict_async(
     approx_contribs: bool,
     pred_interactions: bool,
     validate_features: bool,
-    iteration_range: Tuple[int, int],
+    iteration_range: IterationRange,
     strict_shape: bool,
 ) -> _DaskCollection:
     _booster = await _get_model_future(client, model)
@@ -1410,7 +1410,7 @@ def predict( # pylint: disable=unused-argument
     approx_contribs: bool = False,
     pred_interactions: bool = False,
     validate_features: bool = True,
-    iteration_range: Tuple[int, int] = (0, 0),
+    iteration_range: IterationRange = (0, 0),
     strict_shape: bool = False,
 ) -> Any:
     """Run prediction with a trained booster.
@@ -1458,7 +1458,7 @@ async def _inplace_predict_async( # pylint: disable=too-many-branches
     global_config: Dict[str, Any],
     model: Union[Booster, Dict, "distributed.Future"],
     data: _DataT,
-    iteration_range: Tuple[int, int],
+    iteration_range: IterationRange,
     predict_type: str,
     missing: float,
     validate_features: bool,
@@ -1516,7 +1516,7 @@ def inplace_predict( # pylint: disable=unused-argument
     client: Optional["distributed.Client"],
     model: Union[TrainReturnT, Booster, "distributed.Future"],
     data: _DataT,
-    iteration_range: Tuple[int, int] = (0, 0),
+    iteration_range: IterationRange = (0, 0),
     predict_type: str = "value",
     missing: float = numpy.nan,
     validate_features: bool = True,
@@ -1624,7 +1624,7 @@ class DaskScikitLearnBase(XGBModel):
         output_margin: bool,
         validate_features: bool,
         base_margin: Optional[_DaskCollection],
-        iteration_range: Optional[Tuple[int, int]],
+        iteration_range: Optional[IterationRange],
     ) -> Any:
         iteration_range = self._get_iteration_range(iteration_range)
         if self._can_use_inplace_predict():
@@ -1664,7 +1664,7 @@ class DaskScikitLearnBase(XGBModel):
         output_margin: bool = False,
         validate_features: bool = True,
         base_margin: Optional[_DaskCollection] = None,
-        iteration_range: Optional[Tuple[int, int]] = None,
+        iteration_range: Optional[IterationRange] = None,
     ) -> Any:
         _assert_dask_support()
         return self.client.sync(
@@ -1679,7 +1679,7 @@ class DaskScikitLearnBase(XGBModel):
     async def _apply_async(
         self,
         X: _DataT,
-        iteration_range: Optional[Tuple[int, int]] = None,
+        iteration_range: Optional[IterationRange] = None,
     ) -> Any:
         iteration_range = self._get_iteration_range(iteration_range)
         test_dmatrix = await DaskDMatrix(
@@ -1700,7 +1700,7 @@ class DaskScikitLearnBase(XGBModel):
     def apply(
         self,
         X: _DataT,
-        iteration_range: Optional[Tuple[int, int]] = None,
+        iteration_range: Optional[IterationRange] = None,
     ) -> Any:
         _assert_dask_support()
         return self.client.sync(self._apply_async, X, iteration_range=iteration_range)
@@ -1962,7 +1962,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
         X: _DataT,
         validate_features: bool,
         base_margin: Optional[_DaskCollection],
-        iteration_range: Optional[Tuple[int, int]],
+        iteration_range: Optional[IterationRange],
     ) -> _DaskCollection:
         if self.objective == "multi:softmax":
             raise ValueError(
@@ -1987,7 +1987,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
         X: _DaskCollection,
         validate_features: bool = True,
         base_margin: Optional[_DaskCollection] = None,
-        iteration_range: Optional[Tuple[int, int]] = None,
+        iteration_range: Optional[IterationRange] = None,
     ) -> Any:
         _assert_dask_support()
         return self._client_sync(
@@ -2006,7 +2006,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
         output_margin: bool,
         validate_features: bool,
         base_margin: Optional[_DaskCollection],
-        iteration_range: Optional[Tuple[int, int]],
+        iteration_range: Optional[IterationRange],
     ) -> _DaskCollection:
         pred_probs = await super()._predict_async(
             data, output_margin, validate_features, base_margin, iteration_range
@@ -22,7 +22,7 @@ from typing import (
 import numpy as np
 from scipy.special import softmax
 
-from ._typing import ArrayLike, FeatureNames, FeatureTypes, ModelIn
+from ._typing import ArrayLike, FeatureNames, FeatureTypes, IterationRange, ModelIn
 from .callback import TrainingCallback
 
 # Do not use class names on scikit-learn directly. Re-define the classes on
@@ -1039,8 +1039,8 @@ class XGBModel(XGBModelBase):
         return False
 
     def _get_iteration_range(
-        self, iteration_range: Optional[Tuple[int, int]]
-    ) -> Tuple[int, int]:
+        self, iteration_range: Optional[IterationRange]
+    ) -> IterationRange:
         if iteration_range is None or iteration_range[1] == 0:
             # Use best_iteration if defined.
             try:
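For the scikit-learn wrapper the annotation change means iteration_range may carry NumPy integers as well; a sketch (not part of the diff, mirroring the test_regression change further down):

import numpy as np
import xgboost as xgb
from sklearn.datasets import make_regression

X, y = make_regression(n_samples=128, n_features=4, random_state=0)
reg = xgb.XGBRegressor(n_estimators=10, tree_method="hist").fit(X, y)

# NumPy integer upper bound, as exercised in test_regression below.
margin = reg.predict(X, output_margin=True, iteration_range=(0, np.int16(3)))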
@@ -1057,7 +1057,7 @@ class XGBModel(XGBModelBase):
         output_margin: bool = False,
         validate_features: bool = True,
         base_margin: Optional[ArrayLike] = None,
-        iteration_range: Optional[Tuple[int, int]] = None,
+        iteration_range: Optional[IterationRange] = None,
     ) -> ArrayLike:
         """Predict with `X`. If the model is trained with early stopping, then
         :py:attr:`best_iteration` is used automatically. The estimator uses
@@ -1129,7 +1129,7 @@ class XGBModel(XGBModelBase):
     def apply(
         self,
         X: ArrayLike,
-        iteration_range: Optional[Tuple[int, int]] = None,
+        iteration_range: Optional[IterationRange] = None,
     ) -> np.ndarray:
         """Return the predicted leaf every tree for each sample. If the model is trained
         with early stopping, then :py:attr:`best_iteration` is used automatically.
@@ -1465,7 +1465,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         output_margin: bool = False,
         validate_features: bool = True,
         base_margin: Optional[ArrayLike] = None,
-        iteration_range: Optional[Tuple[int, int]] = None,
+        iteration_range: Optional[IterationRange] = None,
     ) -> ArrayLike:
         with config_context(verbosity=self.verbosity):
             class_probs = super().predict(
@@ -1500,7 +1500,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         X: ArrayLike,
         validate_features: bool = True,
         base_margin: Optional[ArrayLike] = None,
-        iteration_range: Optional[Tuple[int, int]] = None,
+        iteration_range: Optional[IterationRange] = None,
     ) -> np.ndarray:
         """Predict the probability of each `X` example being of a given class. If the
         model is trained with early stopping, then :py:attr:`best_iteration` is used
@@ -1942,7 +1942,7 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
         output_margin: bool = False,
         validate_features: bool = True,
         base_margin: Optional[ArrayLike] = None,
-        iteration_range: Optional[Tuple[int, int]] = None,
+        iteration_range: Optional[IterationRange] = None,
     ) -> ArrayLike:
         X, _ = _get_qid(X, None)
         return super().predict(
@@ -1956,7 +1956,7 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
     def apply(
         self,
         X: ArrayLike,
-        iteration_range: Optional[Tuple[int, int]] = None,
+        iteration_range: Optional[IterationRange] = None,
     ) -> ArrayLike:
         X, _ = _get_qid(X, None)
         return super().apply(X, iteration_range)
@@ -7,6 +7,7 @@ import pytest
 
 import xgboost as xgb
 from xgboost import testing as tm
+from xgboost.core import Integer
 from xgboost.testing.updater import ResetStrategy
 
 dpath = tm.data_dir(__file__)
@@ -97,15 +98,15 @@ class TestModels:
     def test_boost_from_prediction(self):
         # Re-construct dtrain here to avoid modification
         margined, _ = tm.load_agaricus(__file__)
-        bst = xgb.train({'tree_method': 'hist'}, margined, 1)
+        bst = xgb.train({"tree_method": "hist"}, margined, 1)
         predt_0 = bst.predict(margined, output_margin=True)
         margined.set_base_margin(predt_0)
-        bst = xgb.train({'tree_method': 'hist'}, margined, 1)
+        bst = xgb.train({"tree_method": "hist"}, margined, 1)
         predt_1 = bst.predict(margined)
 
         assert np.any(np.abs(predt_1 - predt_0) > 1e-6)
         dtrain, _ = tm.load_agaricus(__file__)
-        bst = xgb.train({'tree_method': 'hist'}, dtrain, 2)
+        bst = xgb.train({"tree_method": "hist"}, dtrain, 2)
         predt_2 = bst.predict(dtrain)
         assert np.all(np.abs(predt_2 - predt_1) < 1e-6)
 
@@ -331,10 +332,15 @@ class TestModels:
         dtrain: xgb.DMatrix,
         num_parallel_tree: int,
         num_classes: int,
-        num_boost_round: int
+        num_boost_round: int,
+        use_np_type: bool,
     ):
         beg = 3
-        end = 7
+        if use_np_type:
+            end: Integer = np.int32(7)
+        else:
+            end = 7
 
         sliced: xgb.Booster = booster[beg:end]
         assert sliced.feature_types == booster.feature_types
 
@@ -345,7 +351,7 @@ class TestModels:
         sliced = booster[beg:end:2]
         assert sliced_trees == len(sliced.get_dump())
 
-        sliced = booster[beg: ...]
+        sliced = booster[beg:]
         sliced_trees = (num_boost_round - beg) * num_parallel_tree * num_classes
         assert sliced_trees == len(sliced.get_dump())
 
@@ -357,7 +363,7 @@ class TestModels:
         sliced_trees = end * num_parallel_tree * num_classes
         assert sliced_trees == len(sliced.get_dump())
 
-        sliced = booster[...: end]
+        sliced = booster[: end]
         sliced_trees = end * num_parallel_tree * num_classes
         assert sliced_trees == len(sliced.get_dump())
 
@@ -383,14 +389,14 @@ class TestModels:
         assert len(trees) == num_boost_round
 
         with pytest.raises(TypeError):
-            booster["wrong type"]
+            booster["wrong type"]  # type: ignore
         with pytest.raises(IndexError):
             booster[: num_boost_round + 1]
         with pytest.raises(ValueError):
             booster[1, 2]  # too many dims
         # setitem is not implemented as model is immutable during slicing.
         with pytest.raises(TypeError):
-            booster[...: end] = booster
+            booster[:end] = booster  # type: ignore
 
         sliced_0 = booster[1:3]
         np.testing.assert_allclose(
@@ -446,15 +452,21 @@ class TestModels:
 
         assert len(booster.get_dump()) == total_trees
 
-        self.run_slice(booster, dtrain, num_parallel_tree, num_classes, num_boost_round)
+        self.run_slice(
+            booster, dtrain, num_parallel_tree, num_classes, num_boost_round, False
+        )
 
         bytesarray = booster.save_raw(raw_format="ubj")
         booster = xgb.Booster(model_file=bytesarray)
-        self.run_slice(booster, dtrain, num_parallel_tree, num_classes, num_boost_round)
+        self.run_slice(
+            booster, dtrain, num_parallel_tree, num_classes, num_boost_round, False
+        )
 
         bytesarray = booster.save_raw(raw_format="deprecated")
         booster = xgb.Booster(model_file=bytesarray)
-        self.run_slice(booster, dtrain, num_parallel_tree, num_classes, num_boost_round)
+        self.run_slice(
+            booster, dtrain, num_parallel_tree, num_classes, num_boost_round, True
+        )
 
     def test_slice_multi(self) -> None:
         from sklearn.datasets import make_classification
@@ -479,7 +491,7 @@ class TestModels:
             },
             num_boost_round=num_boost_round,
             dtrain=Xy,
-            callbacks=[ResetStrategy()]
+            callbacks=[ResetStrategy()],
         )
         sliced = [t for t in booster]
         assert len(sliced) == 16
@@ -61,7 +61,7 @@ def run_predict_leaf(device: str) -> np.ndarray:
 
     validate_leaf_output(leaf, num_parallel_tree)
 
-    n_iters = 2
+    n_iters = np.int32(2)
     sliced = booster.predict(
         m,
         pred_leaf=True,
@@ -440,7 +440,7 @@ def test_regression():
     preds = xgb_model.predict(X[test_index])
     # test other params in XGBRegressor().fit
     preds2 = xgb_model.predict(
-        X[test_index], output_margin=True, iteration_range=(0, 3)
+        X[test_index], output_margin=True, iteration_range=(0, np.int16(3))
    )
     preds3 = xgb_model.predict(
         X[test_index], output_margin=True, iteration_range=None