[MT-TREE] Support prediction cache and model slicing. (#8968)

- Fix prediction range.
- Support prediction cache in mt-hist.
- Support model slicing.
- Make the booster a Python iterable by defining `__iter__`.
- Clean up removed/deprecated parameters.
- Add a new field, `iteration_indptr`, to the output model; it points to the range of trees that belongs to each iteration.
This commit is contained in:
Jiaming Yuan
2023-03-27 23:10:54 +08:00
committed by GitHub
parent c2b3a13e70
commit acc110c251
30 changed files with 502 additions and 343 deletions

View File

@@ -524,7 +524,7 @@ class TestModels:
booster[-1:0]
# we do not accept empty slice.
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="Empty slice"):
booster[1:1]
# stop can not be smaller than begin
with pytest.raises(ValueError, match=r"Invalid.*"):
@@ -615,6 +615,46 @@ class TestModels:
booster = xgb.Booster(model_file=bytesarray)
self.run_slice(booster, dtrain, num_parallel_tree, num_classes, num_boost_round)
def test_slice_multi(self) -> None:
    """Check iteration over a booster trained with ``multi_output_tree``.

    Trains a multi-class model with several parallel trees per round, then
    verifies that

    * iterating the booster yields exactly ``num_boost_round`` items, and
    * the margin predictions of those items sum to the margin prediction of
      the complete booster (within ``atol=1e-5``).
    """
    from sklearn.datasets import make_classification

    num_classes = 3
    # Seed the generator so every run trains on the same data set.
    X, y = make_classification(
        n_samples=1000,
        n_informative=5,
        n_classes=num_classes,
        random_state=0,
    )
    Xy = xgb.DMatrix(data=X, label=y)
    num_parallel_tree = 4
    num_boost_round = 16

    class ResetStrategy(xgb.callback.TrainingCallback):
        """Re-apply the ``multi_strategy`` parameter after every round."""

        def after_iteration(self, model, epoch: int, evals_log) -> bool:
            # NOTE(review): presumably exercises that re-setting the parameter
            # mid-training does not invalidate the model/prediction cache.
            model.set_param({"multi_strategy": "multi_output_tree"})
            # False presumably means "do not stop training" — see the
            # TrainingCallback documentation.
            return False

    booster = xgb.train(
        {
            "num_parallel_tree": num_parallel_tree,
            "num_class": num_classes,
            "booster": "gbtree",
            "objective": "multi:softprob",
            "multi_strategy": "multi_output_tree",
            "tree_method": "hist",
            # NOTE(review): presumably zero so that the per-slice margins add
            # up to the full margin without a repeated intercept — confirm.
            "base_score": 0,
        },
        num_boost_round=num_boost_round,
        dtrain=Xy,
        callbacks=[ResetStrategy()],
    )

    # Iterating the booster produces one item per boosting round; compare
    # against the configured round count rather than a duplicated literal.
    sliced = list(booster)
    assert len(sliced) == num_boost_round

    predt0 = booster.predict(Xy, output_margin=True)
    predt1 = np.zeros(predt0.shape)
    # Accumulate the raw margins of every per-round slice.
    for round_slice in sliced:
        predt1 += round_slice.predict(Xy, output_margin=True)
    np.testing.assert_allclose(predt0, predt1, atol=1e-5)
@pytest.mark.skipif(**tm.no_pandas())
def test_feature_info(self):
import pandas as pd