Fix pruner. (#5335)
* Honor the tree depth. * Prevent pruning pruned node.
This commit is contained in:
@@ -26,6 +26,30 @@ class TestUpdaters(unittest.TestCase):
|
||||
result = run_suite(param)
|
||||
assert_results_non_increasing(result, 1e-2)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_pruner(self):
|
||||
import sklearn
|
||||
params = {'tree_method': 'exact'}
|
||||
cancer = sklearn.datasets.load_breast_cancer()
|
||||
X = cancer['data']
|
||||
y = cancer["target"]
|
||||
|
||||
dtrain = xgb.DMatrix(X, y)
|
||||
booster = xgb.train(params, dtrain=dtrain, num_boost_round=10)
|
||||
grown = str(booster.get_dump())
|
||||
|
||||
params = {'updater': 'prune', 'process_type': 'update', 'gamma': '0.2'}
|
||||
booster = xgb.train(params, dtrain=dtrain, num_boost_round=10,
|
||||
xgb_model=booster)
|
||||
after_prune = str(booster.get_dump())
|
||||
assert grown != after_prune
|
||||
|
||||
booster = xgb.train(params, dtrain=dtrain, num_boost_round=10,
|
||||
xgb_model=booster)
|
||||
second_prune = str(booster.get_dump())
|
||||
# Second prune should not change the tree
|
||||
assert after_prune == second_prune
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_fast_histmaker(self):
|
||||
variable_param = {'tree_method': ['hist'],
|
||||
|
||||
Reference in New Issue
Block a user