[MT-TREE] Support prediction cache and model slicing. (#8968)
- Fix the prediction range. - Support the prediction cache in mt-hist. - Support model slicing. - Make the booster a Python iterable by defining `__iter__`. - Clean up removed/deprecated parameters. - Add a new field, `iteration_indptr`, to the output model; it points to the range of trees that belongs to each iteration.
This commit is contained in:
@@ -524,7 +524,7 @@ class TestModels:
|
||||
booster[-1:0]
|
||||
|
||||
# we do not accept empty slice.
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match="Empty slice"):
|
||||
booster[1:1]
|
||||
# stop can not be smaller than begin
|
||||
with pytest.raises(ValueError, match=r"Invalid.*"):
|
||||
@@ -615,6 +615,46 @@ class TestModels:
|
||||
booster = xgb.Booster(model_file=bytesarray)
|
||||
self.run_slice(booster, dtrain, num_parallel_tree, num_classes, num_boost_round)
|
||||
|
||||
def test_slice_multi(self) -> None:
    """Check iteration/slicing of a booster trained with multi-output trees.

    A 3-class ``multi:softprob`` model is trained with
    ``multi_strategy="multi_output_tree"``.  The test then verifies that

    * iterating over the booster yields one item per boosting round, and
    * the margins predicted by the individual items add up to the margin
      predicted by the complete booster.
    """
    from sklearn.datasets import make_classification

    num_classes = 3
    num_parallel_tree = 4
    num_boost_round = 16

    features, labels = make_classification(
        n_samples=1000, n_informative=5, n_classes=num_classes
    )
    dtrain = xgb.DMatrix(data=features, label=labels)

    class ResetStrategy(xgb.callback.TrainingCallback):
        """Re-apply the ``multi_strategy`` parameter after every iteration."""

        def after_iteration(self, model, epoch: int, evals_log) -> bool:
            model.set_param({"multi_strategy": "multi_output_tree"})
            # NOTE(review): False is expected to mean "keep training" under
            # the TrainingCallback contract -- confirm against the docs.
            return False

    train_params = {
        "num_parallel_tree": num_parallel_tree,
        "num_class": num_classes,
        "booster": "gbtree",
        "objective": "multi:softprob",
        "multi_strategy": "multi_output_tree",
        "tree_method": "hist",
        # presumably zero so that per-round margins can simply be summed
        # below without counting an intercept more than once.
        "base_score": 0,
    }
    model = xgb.train(
        train_params,
        num_boost_round=num_boost_round,
        dtrain=dtrain,
        callbacks=[ResetStrategy()],
    )

    # The booster is iterable; one element is expected per boosting round.
    per_round = [piece for piece in model]
    assert len(per_round) == num_boost_round

    full_margin = model.predict(dtrain, output_margin=True)
    summed_margin = np.zeros(full_margin.shape)
    for piece in model:
        summed_margin += piece.predict(dtrain, output_margin=True)

    np.testing.assert_allclose(full_margin, summed_margin, atol=1e-5)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_pandas())
|
||||
def test_feature_info(self):
|
||||
import pandas as pd
|
||||
|
||||
Reference in New Issue
Block a user