Implement max_cat_threshold for CPU. (#7957)

This commit is contained in:
Jiaming Yuan
2022-06-04 11:02:46 +08:00
committed by GitHub
parent 78694405a6
commit b90c6d25e8
8 changed files with 177 additions and 20 deletions

View File

@@ -74,8 +74,8 @@ class TestGPUUpdaters:
strategies.integers(1, 2), strategies.integers(4, 7))
@settings(deadline=None, print_blob=True)
@pytest.mark.skipif(**tm.no_pandas())
def test_categorical(self, rows, cols, rounds, cats):
self.cputest.run_categorical_basic(rows, cols, rounds, cats, "gpu_hist")
def test_categorical_ohe(self, rows, cols, rounds, cats):
self.cputest.run_categorical_ohe(rows, cols, rounds, cats, "gpu_hist")
@given(
strategies.integers(10, 400),
@@ -96,7 +96,7 @@ class TestGPUUpdaters:
cols = 10
cats = 32
rounds = 4
self.cputest.run_categorical_basic(rows, cols, rounds, cats, "gpu_hist")
self.cputest.run_categorical_ohe(rows, cols, rounds, cats, "gpu_hist")
@pytest.mark.skipif(**tm.no_cupy())
def test_invalid_category(self):