Bound the size of the histogram cache. (#9440)

- A new histogram collection with a limit in size.
- Unify histogram building logic between hist, multi-hist, and approx.
This commit is contained in:
Jiaming Yuan
2023-08-08 03:21:26 +08:00
committed by GitHub
parent 5bd163aa25
commit 54029a59af
27 changed files with 994 additions and 565 deletions

View File

@@ -24,7 +24,7 @@ from sklearn.datasets import make_classification, make_regression
import xgboost as xgb
from xgboost import testing as tm
from xgboost.data import _is_cudf_df
from xgboost.testing.params import hist_parameter_strategy
from xgboost.testing.params import hist_cache_strategy, hist_parameter_strategy
from xgboost.testing.shared import (
get_feature_weights,
validate_data_initialization,
@@ -1512,14 +1512,23 @@ class TestWithDask:
else:
assert history[-1] < history[0]
@given(params=hist_parameter_strategy, dataset=tm.make_dataset_strategy())
@given(
params=hist_parameter_strategy,
cache_param=hist_cache_strategy,
dataset=tm.make_dataset_strategy(),
)
@settings(
deadline=None, max_examples=10, suppress_health_check=suppress, print_blob=True
)
def test_hist(
self, params: Dict, dataset: tm.TestDataset, client: "Client"
self,
params: Dict[str, Any],
cache_param: Dict[str, Any],
dataset: tm.TestDataset,
client: "Client",
) -> None:
num_rounds = 10
params.update(cache_param)
self.run_updater_test(client, params, num_rounds, dataset, "hist")
def test_quantile_dmatrix(self, client: Client) -> None:
@@ -1579,14 +1588,23 @@ class TestWithDask:
rmse = result["history"]["Valid"]["rmse"][-1]
assert rmse < 32.0
@given(params=hist_parameter_strategy, dataset=tm.make_dataset_strategy())
@given(
params=hist_parameter_strategy,
cache_param=hist_cache_strategy,
dataset=tm.make_dataset_strategy()
)
@settings(
deadline=None, max_examples=10, suppress_health_check=suppress, print_blob=True
)
def test_approx(
self, client: "Client", params: Dict, dataset: tm.TestDataset
self,
client: "Client",
params: Dict,
cache_param: Dict[str, Any],
dataset: tm.TestDataset,
) -> None:
num_rounds = 10
params.update(cache_param)
self.run_updater_test(client, params, num_rounds, dataset, "approx")
def test_adaptive(self) -> None:
@@ -2239,7 +2257,7 @@ async def test_worker_left(c, s, a, b):
)
await async_poll_for(lambda: len(s.workers) == 2, timeout=5)
with pytest.raises(RuntimeError, match="Missing"):
await xgb.dask.train(
await xgb.dask.train(
c,
{},
d_train,
@@ -2256,7 +2274,7 @@ async def test_worker_restarted(c, s, a, b):
)
await c.restart_workers([a.worker_address])
with pytest.raises(RuntimeError, match="Missing"):
await xgb.dask.train(
await xgb.dask.train(
c,
{},
d_train,