Implement max_cat_threshold for CPU. (#7957)
This commit is contained in:
@@ -31,6 +31,14 @@ hist_parameter_strategy = strategies.fixed_dictionaries({
|
||||
x['max_depth'] > 0 or x['grow_policy'] == 'lossguide'))
|
||||
|
||||
|
||||
cat_parameter_strategy = strategies.fixed_dictionaries(
|
||||
{
|
||||
"max_cat_to_onehot": strategies.integers(1, 128),
|
||||
"max_cat_threshold": strategies.integers(1, 128),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def train_result(param, dmat, num_rounds):
|
||||
result = {}
|
||||
xgb.train(param, dmat, num_rounds, [(dmat, 'train')], verbose_eval=False,
|
||||
@@ -253,7 +261,7 @@ class TestTreeMethod:
|
||||
# Test with partition-based split
|
||||
run(self.USE_PART)
|
||||
|
||||
def run_categorical_basic(self, rows, cols, rounds, cats, tree_method):
|
||||
def run_categorical_ohe(self, rows, cols, rounds, cats, tree_method):
|
||||
onehot, label = tm.make_categorical(rows, cols, cats, True)
|
||||
cat, _ = tm.make_categorical(rows, cols, cats, False)
|
||||
|
||||
@@ -328,9 +336,55 @@ class TestTreeMethod:
|
||||
strategies.integers(1, 2), strategies.integers(4, 7))
|
||||
@settings(deadline=None, print_blob=True)
|
||||
@pytest.mark.skipif(**tm.no_pandas())
|
||||
def test_categorical(self, rows, cols, rounds, cats):
|
||||
self.run_categorical_basic(rows, cols, rounds, cats, "approx")
|
||||
self.run_categorical_basic(rows, cols, rounds, cats, "hist")
|
||||
def test_categorical_ohe(self, rows, cols, rounds, cats):
|
||||
self.run_categorical_ohe(rows, cols, rounds, cats, "approx")
|
||||
self.run_categorical_ohe(rows, cols, rounds, cats, "hist")
|
||||
|
||||
@given(
|
||||
tm.categorical_dataset_strategy,
|
||||
exact_parameter_strategy,
|
||||
hist_parameter_strategy,
|
||||
cat_parameter_strategy,
|
||||
strategies.integers(4, 32),
|
||||
strategies.sampled_from(["hist", "approx"]),
|
||||
)
|
||||
@settings(deadline=None, print_blob=True)
|
||||
@pytest.mark.skipif(**tm.no_pandas())
|
||||
def test_categorical(
|
||||
self,
|
||||
dataset: tm.TestDataset,
|
||||
exact_parameters: Dict[str, Any],
|
||||
hist_parameters: Dict[str, Any],
|
||||
cat_parameters: Dict[str, Any],
|
||||
n_rounds: int,
|
||||
tree_method: str,
|
||||
) -> None:
|
||||
cat_parameters.update(exact_parameters)
|
||||
cat_parameters.update(hist_parameters)
|
||||
cat_parameters["tree_method"] = tree_method
|
||||
|
||||
results = train_result(cat_parameters, dataset.get_dmat(), n_rounds)
|
||||
tm.non_increasing(results["train"]["rmse"])
|
||||
|
||||
@given(
|
||||
hist_parameter_strategy,
|
||||
cat_parameter_strategy,
|
||||
strategies.sampled_from(["hist", "approx"]),
|
||||
)
|
||||
@settings(deadline=None, print_blob=True)
|
||||
def test_categorical_ames_housing(
|
||||
self,
|
||||
hist_parameters: Dict[str, Any],
|
||||
cat_parameters: Dict[str, Any],
|
||||
tree_method: str,
|
||||
) -> None:
|
||||
cat_parameters.update(hist_parameters)
|
||||
dataset = tm.TestDataset(
|
||||
"ames_housing", tm.get_ames_housing, "reg:squarederror", "rmse"
|
||||
)
|
||||
cat_parameters["tree_method"] = tree_method
|
||||
results = train_result(cat_parameters, dataset.get_dmat(), 16)
|
||||
tm.non_increasing(results["train"]["rmse"])
|
||||
|
||||
@given(
|
||||
strategies.integers(10, 400),
|
||||
|
||||
Reference in New Issue
Block a user