Unify evaluation functions. (#6037)
This commit is contained in:
@@ -198,6 +198,9 @@ class TestGPUPredict(unittest.TestCase):
|
||||
tm.dataset_strategy, shap_parameter_strategy, strategies.booleans())
|
||||
@settings(deadline=None)
|
||||
def test_shap(self, num_rounds, dataset, param, all_rows):
|
||||
if param['max_depth'] == 0 and param['max_leaves'] == 0:
|
||||
return
|
||||
|
||||
param.update({"predictor": "gpu_predictor", "gpu_id": 0})
|
||||
param = dataset.set_params(param)
|
||||
dmat = dataset.get_dmat()
|
||||
|
||||
@@ -7,6 +7,7 @@ import pytest
|
||||
import xgboost as xgb
|
||||
sys.path.append("tests/python")
|
||||
import testing as tm
|
||||
import test_monotone_constraints as tmc
|
||||
|
||||
rng = np.random.RandomState(1994)
|
||||
|
||||
@@ -30,6 +31,7 @@ def assert_constraint(constraint, tree_method):
|
||||
bst = xgb.train(param, dtrain)
|
||||
dpredict = xgb.DMatrix(X[X[:, 0].argsort()])
|
||||
pred = bst.predict(dpredict)
|
||||
|
||||
if constraint > 0:
|
||||
assert non_decreasing(pred)
|
||||
elif constraint < 0:
|
||||
@@ -38,11 +40,24 @@ def assert_constraint(constraint, tree_method):
|
||||
|
||||
class TestMonotonicConstraints(unittest.TestCase):
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_exact(self):
|
||||
assert_constraint(1, 'exact')
|
||||
assert_constraint(-1, 'exact')
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_gpu_hist(self):
|
||||
def test_gpu_hist_basic(self):
|
||||
assert_constraint(1, 'gpu_hist')
|
||||
assert_constraint(-1, 'gpu_hist')
|
||||
|
||||
def test_gpu_hist_depthwise(self):
|
||||
params = {
|
||||
'tree_method': 'gpu_hist',
|
||||
'grow_policy': 'depthwise',
|
||||
'monotone_constraints': '(1, -1)'
|
||||
}
|
||||
model = xgb.train(params, tmc.training_dset)
|
||||
tmc.is_correctly_constrained(model)
|
||||
|
||||
def test_gpu_hist_lossguide(self):
|
||||
params = {
|
||||
'tree_method': 'gpu_hist',
|
||||
'grow_policy': 'lossguide',
|
||||
'monotone_constraints': '(1, -1)'
|
||||
}
|
||||
model = xgb.train(params, tmc.training_dset)
|
||||
tmc.is_correctly_constrained(model)
|
||||
|
||||
Reference in New Issue
Block a user