Bound the size of the histogram cache. (#9440)
- Add a new histogram collection with a limit on its size.
- Unify the histogram building logic between hist, multi-hist, and approx.
This commit is contained in:
@@ -24,7 +24,7 @@ from sklearn.datasets import make_classification, make_regression
|
||||
import xgboost as xgb
|
||||
from xgboost import testing as tm
|
||||
from xgboost.data import _is_cudf_df
|
||||
from xgboost.testing.params import hist_parameter_strategy
|
||||
from xgboost.testing.params import hist_cache_strategy, hist_parameter_strategy
|
||||
from xgboost.testing.shared import (
|
||||
get_feature_weights,
|
||||
validate_data_initialization,
|
||||
@@ -1512,14 +1512,23 @@ class TestWithDask:
|
||||
else:
|
||||
assert history[-1] < history[0]
|
||||
|
||||
@given(params=hist_parameter_strategy, dataset=tm.make_dataset_strategy())
|
||||
@given(
|
||||
params=hist_parameter_strategy,
|
||||
cache_param=hist_cache_strategy,
|
||||
dataset=tm.make_dataset_strategy(),
|
||||
)
|
||||
@settings(
|
||||
deadline=None, max_examples=10, suppress_health_check=suppress, print_blob=True
|
||||
)
|
||||
def test_hist(
|
||||
self, params: Dict, dataset: tm.TestDataset, client: "Client"
|
||||
self,
|
||||
params: Dict[str, Any],
|
||||
cache_param: Dict[str, Any],
|
||||
dataset: tm.TestDataset,
|
||||
client: "Client",
|
||||
) -> None:
|
||||
num_rounds = 10
|
||||
params.update(cache_param)
|
||||
self.run_updater_test(client, params, num_rounds, dataset, "hist")
|
||||
|
||||
def test_quantile_dmatrix(self, client: Client) -> None:
|
||||
@@ -1579,14 +1588,23 @@ class TestWithDask:
|
||||
rmse = result["history"]["Valid"]["rmse"][-1]
|
||||
assert rmse < 32.0
|
||||
|
||||
@given(params=hist_parameter_strategy, dataset=tm.make_dataset_strategy())
|
||||
@given(
|
||||
params=hist_parameter_strategy,
|
||||
cache_param=hist_cache_strategy,
|
||||
dataset=tm.make_dataset_strategy()
|
||||
)
|
||||
@settings(
|
||||
deadline=None, max_examples=10, suppress_health_check=suppress, print_blob=True
|
||||
)
|
||||
def test_approx(
|
||||
self, client: "Client", params: Dict, dataset: tm.TestDataset
|
||||
self,
|
||||
client: "Client",
|
||||
params: Dict,
|
||||
cache_param: Dict[str, Any],
|
||||
dataset: tm.TestDataset,
|
||||
) -> None:
|
||||
num_rounds = 10
|
||||
params.update(cache_param)
|
||||
self.run_updater_test(client, params, num_rounds, dataset, "approx")
|
||||
|
||||
def test_adaptive(self) -> None:
|
||||
@@ -2239,7 +2257,7 @@ async def test_worker_left(c, s, a, b):
|
||||
)
|
||||
await async_poll_for(lambda: len(s.workers) == 2, timeout=5)
|
||||
with pytest.raises(RuntimeError, match="Missing"):
|
||||
await xgb.dask.train(
|
||||
await xgb.dask.train(
|
||||
c,
|
||||
{},
|
||||
d_train,
|
||||
@@ -2256,7 +2274,7 @@ async def test_worker_restarted(c, s, a, b):
|
||||
)
|
||||
await c.restart_workers([a.worker_address])
|
||||
with pytest.raises(RuntimeError, match="Missing"):
|
||||
await xgb.dask.train(
|
||||
await xgb.dask.train(
|
||||
c,
|
||||
{},
|
||||
d_train,
|
||||
|
||||
Reference in New Issue
Block a user