Remove ntree limit in python package. (#8345)
- Remove `ntree_limit`. The parameter has been deprecated since 1.4.0. - The SHAP package compatibility is broken.
This commit is contained in:
@@ -63,9 +63,15 @@ def test_multiclass_classification(objective):
|
||||
assert xgb_model.get_booster().num_boosted_rounds() == 100
|
||||
preds = xgb_model.predict(X[test_index])
|
||||
# test other params in XGBClassifier().fit
|
||||
preds2 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=3)
|
||||
preds3 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=0)
|
||||
preds4 = xgb_model.predict(X[test_index], output_margin=False, ntree_limit=3)
|
||||
preds2 = xgb_model.predict(
|
||||
X[test_index], output_margin=True, iteration_range=(0, 1)
|
||||
)
|
||||
preds3 = xgb_model.predict(
|
||||
X[test_index], output_margin=True, iteration_range=None
|
||||
)
|
||||
preds4 = xgb_model.predict(
|
||||
X[test_index], output_margin=False, iteration_range=(0, 1)
|
||||
)
|
||||
labels = y[test_index]
|
||||
|
||||
check_pred(preds, labels, output_margin=False)
|
||||
@@ -86,25 +92,21 @@ def test_multiclass_classification(objective):
|
||||
assert proba.shape[1] == cls.n_classes_
|
||||
|
||||
|
||||
def test_best_ntree_limit():
|
||||
def test_best_iteration():
|
||||
from sklearn.datasets import load_iris
|
||||
|
||||
X, y = load_iris(return_X_y=True)
|
||||
|
||||
def train(booster, forest):
|
||||
def train(booster: str, forest: Optional[int]) -> None:
|
||||
rounds = 4
|
||||
cls = xgb.XGBClassifier(
|
||||
n_estimators=rounds, num_parallel_tree=forest, booster=booster
|
||||
).fit(
|
||||
X, y, eval_set=[(X, y)], early_stopping_rounds=3
|
||||
)
|
||||
assert cls.best_iteration == rounds - 1
|
||||
|
||||
if forest:
|
||||
assert cls.best_ntree_limit == rounds * forest
|
||||
else:
|
||||
assert cls.best_ntree_limit == 0
|
||||
|
||||
# best_ntree_limit is used by default, assert that under gblinear it's
|
||||
# best_iteration is used by default, assert that under gblinear it's
|
||||
# automatically ignored due to being 0.
|
||||
cls.predict(X)
|
||||
|
||||
@@ -430,12 +432,15 @@ def test_regression():
|
||||
|
||||
preds = xgb_model.predict(X[test_index])
|
||||
# test other params in XGBRegressor().fit
|
||||
preds2 = xgb_model.predict(X[test_index], output_margin=True,
|
||||
ntree_limit=3)
|
||||
preds3 = xgb_model.predict(X[test_index], output_margin=True,
|
||||
ntree_limit=0)
|
||||
preds4 = xgb_model.predict(X[test_index], output_margin=False,
|
||||
ntree_limit=3)
|
||||
preds2 = xgb_model.predict(
|
||||
X[test_index], output_margin=True, iteration_range=(0, 3)
|
||||
)
|
||||
preds3 = xgb_model.predict(
|
||||
X[test_index], output_margin=True, iteration_range=None
|
||||
)
|
||||
preds4 = xgb_model.predict(
|
||||
X[test_index], output_margin=False, iteration_range=(0, 3)
|
||||
)
|
||||
labels = y[test_index]
|
||||
|
||||
assert mean_squared_error(preds, labels) < 25
|
||||
|
||||
Reference in New Issue
Block a user