More categorical tests and disable shap sparse test. (#6219)
* Fix tree load with 32 category.
This commit is contained in:
@@ -41,7 +41,24 @@ class TestGPUUpdaters:
|
||||
note(result)
|
||||
assert tm.non_increasing(result['train'][dataset.metric])
|
||||
|
||||
def run_categorical_basic(self, cat, onehot, label, rounds):
|
||||
def run_categorical_basic(self, rows, cols, rounds, cats):
|
||||
import pandas as pd
|
||||
rng = np.random.RandomState(1994)
|
||||
|
||||
pd_dict = {}
|
||||
for i in range(cols):
|
||||
c = rng.randint(low=0, high=cats+1, size=rows)
|
||||
pd_dict[str(i)] = pd.Series(c, dtype=np.int64)
|
||||
|
||||
df = pd.DataFrame(pd_dict)
|
||||
label = df.iloc[:, 0]
|
||||
for i in range(0, cols-1):
|
||||
label += df.iloc[:, i]
|
||||
label += 1
|
||||
df = df.astype('category')
|
||||
onehot = pd.get_dummies(df)
|
||||
cat = df
|
||||
|
||||
by_etl_results = {}
|
||||
by_builtin_results = {}
|
||||
|
||||
@@ -64,28 +81,20 @@ class TestGPUUpdaters:
|
||||
rtol=1e-3)
|
||||
assert tm.non_increasing(by_builtin_results['Train']['rmse'])
|
||||
|
||||
@given(strategies.integers(10, 400), strategies.integers(5, 10),
|
||||
strategies.integers(1, 5), strategies.integers(4, 8))
|
||||
@given(strategies.integers(10, 400), strategies.integers(3, 8),
|
||||
strategies.integers(1, 5), strategies.integers(4, 7))
|
||||
@settings(deadline=None)
|
||||
@pytest.mark.skipif(**tm.no_pandas())
|
||||
def test_categorical(self, rows, cols, rounds, cats):
|
||||
import pandas as pd
|
||||
rng = np.random.RandomState(1994)
|
||||
self.run_categorical_basic(rows, cols, rounds, cats)
|
||||
|
||||
pd_dict = {}
|
||||
for i in range(cols):
|
||||
c = rng.randint(low=0, high=cats+1, size=rows)
|
||||
pd_dict[str(i)] = pd.Series(c, dtype=np.int64)
|
||||
|
||||
df = pd.DataFrame(pd_dict)
|
||||
label = df.iloc[:, 0]
|
||||
for i in range(0, cols-1):
|
||||
label += df.iloc[:, i]
|
||||
label += 1
|
||||
df = df.astype('category')
|
||||
x = pd.get_dummies(df)
|
||||
|
||||
self.run_categorical_basic(df, x, label, rounds)
|
||||
def test_categorical_32_cat(self):
|
||||
'''32 hits the bound of integer bitset, so special test'''
|
||||
rows = 1000
|
||||
cols = 10
|
||||
cats = 32
|
||||
rounds = 4
|
||||
self.run_categorical_basic(rows, cols, rounds, cats)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
@given(parameter_strategy, strategies.integers(1, 20),
|
||||
|
||||
Reference in New Issue
Block a user