Fix multi-output with alternating strategies. (#9933)

---------

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan
2024-01-04 16:41:13 +08:00
committed by GitHub
parent 5f7b5a6921
commit 621348abb3
6 changed files with 123 additions and 73 deletions

View File

@@ -12,7 +12,6 @@ from xgboost.testing.params import (
cat_parameter_strategy,
exact_parameter_strategy,
hist_cache_strategy,
hist_multi_parameter_strategy,
hist_parameter_strategy,
)
from xgboost.testing.updater import (
@@ -25,69 +24,6 @@ from xgboost.testing.updater import (
)
class TestTreeMethodMulti:
@given(
exact_parameter_strategy, strategies.integers(1, 20), tm.multi_dataset_strategy
)
@settings(deadline=None, print_blob=True)
def test_exact(self, param: dict, num_rounds: int, dataset: tm.TestDataset) -> None:
if dataset.name.endswith("-l1"):
return
param["tree_method"] = "exact"
param = dataset.set_params(param)
result = train_result(param, dataset.get_dmat(), num_rounds)
assert tm.non_increasing(result["train"][dataset.metric])
@given(
exact_parameter_strategy,
hist_parameter_strategy,
hist_cache_strategy,
strategies.integers(1, 20),
tm.multi_dataset_strategy,
)
@settings(deadline=None, print_blob=True)
def test_approx(
self, param: Dict[str, Any],
hist_param: Dict[str, Any],
cache_param: Dict[str, Any],
num_rounds: int,
dataset: tm.TestDataset,
) -> None:
param["tree_method"] = "approx"
param = dataset.set_params(param)
param.update(hist_param)
param.update(cache_param)
result = train_result(param, dataset.get_dmat(), num_rounds)
note(str(result))
assert tm.non_increasing(result["train"][dataset.metric])
@given(
exact_parameter_strategy,
hist_multi_parameter_strategy,
hist_cache_strategy,
strategies.integers(1, 20),
tm.multi_dataset_strategy,
)
@settings(deadline=None, print_blob=True)
def test_hist(
self,
param: Dict[str, Any],
hist_param: Dict[str, Any],
cache_param: Dict[str, Any],
num_rounds: int,
dataset: tm.TestDataset,
) -> None:
if dataset.name.endswith("-l1"):
return
param["tree_method"] = "hist"
param = dataset.set_params(param)
param.update(hist_param)
param.update(cache_param)
result = train_result(param, dataset.get_dmat(), num_rounds)
note(str(result))
assert tm.non_increasing(result["train"][dataset.metric])
class TestTreeMethod:
USE_ONEHOT = np.iinfo(np.int32).max
USE_PART = 1