Remove ntree limit in python package. (#8345)

- Remove `ntree_limit`. The parameter has been deprecated since 1.4.0.
- The SHAP package compatibility is broken.
This commit is contained in:
Jiaming Yuan
2023-03-31 19:01:55 +08:00
committed by GitHub
parent b647403baa
commit bac22734fb
17 changed files with 284 additions and 357 deletions

View File

@@ -95,44 +95,39 @@ class TestTrainingContinuation:
res2 = mean_squared_error(y_2class, gbdt_03b.predict(dtrain_2class))
assert res1 == res2
gbdt_04 = xgb.train(xgb_params_02, dtrain_2class,
num_boost_round=3)
assert gbdt_04.best_ntree_limit == (gbdt_04.best_iteration +
1) * self.num_parallel_tree
gbdt_04 = xgb.train(xgb_params_02, dtrain_2class, num_boost_round=3)
res1 = mean_squared_error(y_2class, gbdt_04.predict(dtrain_2class))
res2 = mean_squared_error(y_2class,
gbdt_04.predict(
dtrain_2class,
ntree_limit=gbdt_04.best_ntree_limit))
res2 = mean_squared_error(
y_2class,
gbdt_04.predict(
dtrain_2class, iteration_range=(0, gbdt_04.best_iteration + 1)
)
)
assert res1 == res2
gbdt_04 = xgb.train(xgb_params_02, dtrain_2class,
num_boost_round=7, xgb_model=gbdt_04)
assert gbdt_04.best_ntree_limit == (
gbdt_04.best_iteration + 1) * self.num_parallel_tree
gbdt_04 = xgb.train(
xgb_params_02, dtrain_2class, num_boost_round=7, xgb_model=gbdt_04
)
res1 = mean_squared_error(y_2class, gbdt_04.predict(dtrain_2class))
res2 = mean_squared_error(y_2class,
gbdt_04.predict(
dtrain_2class,
ntree_limit=gbdt_04.best_ntree_limit))
res2 = mean_squared_error(
y_2class,
gbdt_04.predict(
dtrain_2class, iteration_range=(0, gbdt_04.best_iteration + 1)
)
)
assert res1 == res2
gbdt_05 = xgb.train(xgb_params_03, dtrain_5class,
num_boost_round=7)
assert gbdt_05.best_ntree_limit == (
gbdt_05.best_iteration + 1) * self.num_parallel_tree
gbdt_05 = xgb.train(xgb_params_03,
dtrain_5class,
num_boost_round=3,
xgb_model=gbdt_05)
assert gbdt_05.best_ntree_limit == (
gbdt_05.best_iteration + 1) * self.num_parallel_tree
res1 = gbdt_05.predict(dtrain_5class)
res2 = gbdt_05.predict(dtrain_5class,
ntree_limit=gbdt_05.best_ntree_limit)
res2 = gbdt_05.predict(
dtrain_5class, iteration_range=(0, gbdt_05.best_iteration + 1)
)
np.testing.assert_almost_equal(res1, res2)
@pytest.mark.skipif(**tm.no_sklearn())