Remove ntree limit in python package. (#8345)
- Remove `ntree_limit`. The parameter has been deprecated since 1.4.0. - The SHAP package compatibility is broken.
This commit is contained in:
@@ -95,44 +95,39 @@ class TestTrainingContinuation:
|
||||
res2 = mean_squared_error(y_2class, gbdt_03b.predict(dtrain_2class))
|
||||
assert res1 == res2
|
||||
|
||||
gbdt_04 = xgb.train(xgb_params_02, dtrain_2class,
|
||||
num_boost_round=3)
|
||||
assert gbdt_04.best_ntree_limit == (gbdt_04.best_iteration +
|
||||
1) * self.num_parallel_tree
|
||||
|
||||
gbdt_04 = xgb.train(xgb_params_02, dtrain_2class, num_boost_round=3)
|
||||
res1 = mean_squared_error(y_2class, gbdt_04.predict(dtrain_2class))
|
||||
res2 = mean_squared_error(y_2class,
|
||||
gbdt_04.predict(
|
||||
dtrain_2class,
|
||||
ntree_limit=gbdt_04.best_ntree_limit))
|
||||
res2 = mean_squared_error(
|
||||
y_2class,
|
||||
gbdt_04.predict(
|
||||
dtrain_2class, iteration_range=(0, gbdt_04.best_iteration + 1)
|
||||
)
|
||||
)
|
||||
assert res1 == res2
|
||||
|
||||
gbdt_04 = xgb.train(xgb_params_02, dtrain_2class,
|
||||
num_boost_round=7, xgb_model=gbdt_04)
|
||||
assert gbdt_04.best_ntree_limit == (
|
||||
gbdt_04.best_iteration + 1) * self.num_parallel_tree
|
||||
|
||||
gbdt_04 = xgb.train(
|
||||
xgb_params_02, dtrain_2class, num_boost_round=7, xgb_model=gbdt_04
|
||||
)
|
||||
res1 = mean_squared_error(y_2class, gbdt_04.predict(dtrain_2class))
|
||||
res2 = mean_squared_error(y_2class,
|
||||
gbdt_04.predict(
|
||||
dtrain_2class,
|
||||
ntree_limit=gbdt_04.best_ntree_limit))
|
||||
res2 = mean_squared_error(
|
||||
y_2class,
|
||||
gbdt_04.predict(
|
||||
dtrain_2class, iteration_range=(0, gbdt_04.best_iteration + 1)
|
||||
)
|
||||
)
|
||||
assert res1 == res2
|
||||
|
||||
gbdt_05 = xgb.train(xgb_params_03, dtrain_5class,
|
||||
num_boost_round=7)
|
||||
assert gbdt_05.best_ntree_limit == (
|
||||
gbdt_05.best_iteration + 1) * self.num_parallel_tree
|
||||
gbdt_05 = xgb.train(xgb_params_03,
|
||||
dtrain_5class,
|
||||
num_boost_round=3,
|
||||
xgb_model=gbdt_05)
|
||||
assert gbdt_05.best_ntree_limit == (
|
||||
gbdt_05.best_iteration + 1) * self.num_parallel_tree
|
||||
|
||||
res1 = gbdt_05.predict(dtrain_5class)
|
||||
res2 = gbdt_05.predict(dtrain_5class,
|
||||
ntree_limit=gbdt_05.best_ntree_limit)
|
||||
res2 = gbdt_05.predict(
|
||||
dtrain_5class, iteration_range=(0, gbdt_05.best_iteration + 1)
|
||||
)
|
||||
np.testing.assert_almost_equal(res1, res2)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
|
||||
Reference in New Issue
Block a user