Handle np integer in model slice and prediction. (#10007)
This commit is contained in:
@@ -7,6 +7,7 @@ import pytest
|
||||
|
||||
import xgboost as xgb
|
||||
from xgboost import testing as tm
|
||||
from xgboost.core import Integer
|
||||
from xgboost.testing.updater import ResetStrategy
|
||||
|
||||
dpath = tm.data_dir(__file__)
|
||||
@@ -97,15 +98,15 @@ class TestModels:
|
||||
def test_boost_from_prediction(self):
|
||||
# Re-construct dtrain here to avoid modification
|
||||
margined, _ = tm.load_agaricus(__file__)
|
||||
bst = xgb.train({'tree_method': 'hist'}, margined, 1)
|
||||
bst = xgb.train({"tree_method": "hist"}, margined, 1)
|
||||
predt_0 = bst.predict(margined, output_margin=True)
|
||||
margined.set_base_margin(predt_0)
|
||||
bst = xgb.train({'tree_method': 'hist'}, margined, 1)
|
||||
bst = xgb.train({"tree_method": "hist"}, margined, 1)
|
||||
predt_1 = bst.predict(margined)
|
||||
|
||||
assert np.any(np.abs(predt_1 - predt_0) > 1e-6)
|
||||
dtrain, _ = tm.load_agaricus(__file__)
|
||||
bst = xgb.train({'tree_method': 'hist'}, dtrain, 2)
|
||||
bst = xgb.train({"tree_method": "hist"}, dtrain, 2)
|
||||
predt_2 = bst.predict(dtrain)
|
||||
assert np.all(np.abs(predt_2 - predt_1) < 1e-6)
|
||||
|
||||
@@ -331,10 +332,15 @@ class TestModels:
|
||||
dtrain: xgb.DMatrix,
|
||||
num_parallel_tree: int,
|
||||
num_classes: int,
|
||||
num_boost_round: int
|
||||
num_boost_round: int,
|
||||
use_np_type: bool,
|
||||
):
|
||||
beg = 3
|
||||
end = 7
|
||||
if use_np_type:
|
||||
end: Integer = np.int32(7)
|
||||
else:
|
||||
end = 7
|
||||
|
||||
sliced: xgb.Booster = booster[beg:end]
|
||||
assert sliced.feature_types == booster.feature_types
|
||||
|
||||
@@ -345,7 +351,7 @@ class TestModels:
|
||||
sliced = booster[beg:end:2]
|
||||
assert sliced_trees == len(sliced.get_dump())
|
||||
|
||||
sliced = booster[beg: ...]
|
||||
sliced = booster[beg:]
|
||||
sliced_trees = (num_boost_round - beg) * num_parallel_tree * num_classes
|
||||
assert sliced_trees == len(sliced.get_dump())
|
||||
|
||||
@@ -357,7 +363,7 @@ class TestModels:
|
||||
sliced_trees = end * num_parallel_tree * num_classes
|
||||
assert sliced_trees == len(sliced.get_dump())
|
||||
|
||||
sliced = booster[...: end]
|
||||
sliced = booster[: end]
|
||||
sliced_trees = end * num_parallel_tree * num_classes
|
||||
assert sliced_trees == len(sliced.get_dump())
|
||||
|
||||
@@ -383,14 +389,14 @@ class TestModels:
|
||||
assert len(trees) == num_boost_round
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
booster["wrong type"]
|
||||
booster["wrong type"] # type: ignore
|
||||
with pytest.raises(IndexError):
|
||||
booster[: num_boost_round + 1]
|
||||
with pytest.raises(ValueError):
|
||||
booster[1, 2] # too many dims
|
||||
# setitem is not implemented as model is immutable during slicing.
|
||||
with pytest.raises(TypeError):
|
||||
booster[...: end] = booster
|
||||
booster[:end] = booster # type: ignore
|
||||
|
||||
sliced_0 = booster[1:3]
|
||||
np.testing.assert_allclose(
|
||||
@@ -446,15 +452,21 @@ class TestModels:
|
||||
|
||||
assert len(booster.get_dump()) == total_trees
|
||||
|
||||
self.run_slice(booster, dtrain, num_parallel_tree, num_classes, num_boost_round)
|
||||
self.run_slice(
|
||||
booster, dtrain, num_parallel_tree, num_classes, num_boost_round, False
|
||||
)
|
||||
|
||||
bytesarray = booster.save_raw(raw_format="ubj")
|
||||
booster = xgb.Booster(model_file=bytesarray)
|
||||
self.run_slice(booster, dtrain, num_parallel_tree, num_classes, num_boost_round)
|
||||
self.run_slice(
|
||||
booster, dtrain, num_parallel_tree, num_classes, num_boost_round, False
|
||||
)
|
||||
|
||||
bytesarray = booster.save_raw(raw_format="deprecated")
|
||||
booster = xgb.Booster(model_file=bytesarray)
|
||||
self.run_slice(booster, dtrain, num_parallel_tree, num_classes, num_boost_round)
|
||||
self.run_slice(
|
||||
booster, dtrain, num_parallel_tree, num_classes, num_boost_round, True
|
||||
)
|
||||
|
||||
def test_slice_multi(self) -> None:
|
||||
from sklearn.datasets import make_classification
|
||||
@@ -479,7 +491,7 @@ class TestModels:
|
||||
},
|
||||
num_boost_round=num_boost_round,
|
||||
dtrain=Xy,
|
||||
callbacks=[ResetStrategy()]
|
||||
callbacks=[ResetStrategy()],
|
||||
)
|
||||
sliced = [t for t in booster]
|
||||
assert len(sliced) == 16
|
||||
|
||||
Reference in New Issue
Block a user