Implement max_cat_threshold for CPU. (#7957)

This commit is contained in:
Jiaming Yuan
2022-06-04 11:02:46 +08:00
committed by GitHub
parent 78694405a6
commit b90c6d25e8
8 changed files with 177 additions and 20 deletions

View File

@@ -31,6 +31,14 @@ hist_parameter_strategy = strategies.fixed_dictionaries({
x['max_depth'] > 0 or x['grow_policy'] == 'lossguide'))
cat_parameter_strategy = strategies.fixed_dictionaries(
{
"max_cat_to_onehot": strategies.integers(1, 128),
"max_cat_threshold": strategies.integers(1, 128),
}
)
def train_result(param, dmat, num_rounds):
result = {}
xgb.train(param, dmat, num_rounds, [(dmat, 'train')], verbose_eval=False,
@@ -253,7 +261,7 @@ class TestTreeMethod:
# Test with partition-based split
run(self.USE_PART)
def run_categorical_basic(self, rows, cols, rounds, cats, tree_method):
def run_categorical_ohe(self, rows, cols, rounds, cats, tree_method):
onehot, label = tm.make_categorical(rows, cols, cats, True)
cat, _ = tm.make_categorical(rows, cols, cats, False)
@@ -328,9 +336,55 @@ class TestTreeMethod:
strategies.integers(1, 2), strategies.integers(4, 7))
@settings(deadline=None, print_blob=True)
@pytest.mark.skipif(**tm.no_pandas())
def test_categorical(self, rows, cols, rounds, cats):
self.run_categorical_basic(rows, cols, rounds, cats, "approx")
self.run_categorical_basic(rows, cols, rounds, cats, "hist")
def test_categorical_ohe(self, rows, cols, rounds, cats):
self.run_categorical_ohe(rows, cols, rounds, cats, "approx")
self.run_categorical_ohe(rows, cols, rounds, cats, "hist")
@given(
tm.categorical_dataset_strategy,
exact_parameter_strategy,
hist_parameter_strategy,
cat_parameter_strategy,
strategies.integers(4, 32),
strategies.sampled_from(["hist", "approx"]),
)
@settings(deadline=None, print_blob=True)
@pytest.mark.skipif(**tm.no_pandas())
def test_categorical(
self,
dataset: tm.TestDataset,
exact_parameters: Dict[str, Any],
hist_parameters: Dict[str, Any],
cat_parameters: Dict[str, Any],
n_rounds: int,
tree_method: str,
) -> None:
cat_parameters.update(exact_parameters)
cat_parameters.update(hist_parameters)
cat_parameters["tree_method"] = tree_method
results = train_result(cat_parameters, dataset.get_dmat(), n_rounds)
tm.non_increasing(results["train"]["rmse"])
@given(
hist_parameter_strategy,
cat_parameter_strategy,
strategies.sampled_from(["hist", "approx"]),
)
@settings(deadline=None, print_blob=True)
def test_categorical_ames_housing(
self,
hist_parameters: Dict[str, Any],
cat_parameters: Dict[str, Any],
tree_method: str,
) -> None:
cat_parameters.update(hist_parameters)
dataset = tm.TestDataset(
"ames_housing", tm.get_ames_housing, "reg:squarederror", "rmse"
)
cat_parameters["tree_method"] = tree_method
results = train_result(cat_parameters, dataset.get_dmat(), 16)
tm.non_increasing(results["train"]["rmse"])
@given(
strategies.integers(10, 400),