Fix pruner. (#5335)

* Honor the tree depth.
* Prevent pruning pruned node.
This commit is contained in:
Jiaming Yuan
2020-02-25 08:32:46 +08:00
committed by GitHub
parent b0ed3f0a66
commit e0509b3307
5 changed files with 99 additions and 34 deletions

View File

@@ -26,6 +26,30 @@ class TestUpdaters(unittest.TestCase):
result = run_suite(param)
assert_results_non_increasing(result, 1e-2)
@pytest.mark.skipif(**tm.no_sklearn())
def test_pruner(self):
import sklearn
params = {'tree_method': 'exact'}
cancer = sklearn.datasets.load_breast_cancer()
X = cancer['data']
y = cancer["target"]
dtrain = xgb.DMatrix(X, y)
booster = xgb.train(params, dtrain=dtrain, num_boost_round=10)
grown = str(booster.get_dump())
params = {'updater': 'prune', 'process_type': 'update', 'gamma': '0.2'}
booster = xgb.train(params, dtrain=dtrain, num_boost_round=10,
xgb_model=booster)
after_prune = str(booster.get_dump())
assert grown != after_prune
booster = xgb.train(params, dtrain=dtrain, num_boost_round=10,
xgb_model=booster)
second_prune = str(booster.get_dump())
# Second prune should not change the tree
assert after_prune == second_prune
@pytest.mark.skipif(**tm.no_sklearn())
def test_fast_histmaker(self):
variable_param = {'tree_method': ['hist'],