Add tests for prediction cache. (#7650)

* Extract the test from approx for other tree methods.
* Add note on how it works.
This commit is contained in:
Jiaming Yuan
2022-02-15 00:28:00 +08:00
committed by GitHub
parent 5cd1f71b51
commit 2369d55e9a
5 changed files with 134 additions and 57 deletions

View File

@@ -26,10 +26,19 @@ parameter_strategy = strategies.fixed_dictionaries({
x['max_depth'] > 0 or x['grow_policy'] == 'lossguide'))
def train_result(param, dmat, num_rounds):
result = {}
xgb.train(param, dmat, num_rounds, [(dmat, 'train')], verbose_eval=False,
evals_result=result)
def train_result(param, dmat: xgb.DMatrix, num_rounds: int) -> dict:
result: xgb.callback.TrainingCallback.EvalsLog = {}
booster = xgb.train(
param,
dmat,
num_rounds,
[(dmat, "train")],
verbose_eval=False,
evals_result=result,
)
assert booster.num_features() == dmat.num_col()
assert booster.num_boosted_rounds() == num_rounds
return result