* [backport] Fix `best_ntree_limit` for dart and gblinear. (#6579) * Backport num group test fix.
This commit is contained in:
@@ -123,13 +123,13 @@ class TestTrainingContinuation:
|
||||
gbdt_05 = xgb.train(xgb_params_03, dtrain_5class,
|
||||
num_boost_round=7)
|
||||
assert gbdt_05.best_ntree_limit == (
|
||||
gbdt_05.best_iteration + 1) * self.num_parallel_tree
|
||||
gbdt_05.best_iteration + 1) * self.num_parallel_tree * 5
|
||||
gbdt_05 = xgb.train(xgb_params_03,
|
||||
dtrain_5class,
|
||||
num_boost_round=3,
|
||||
xgb_model=gbdt_05)
|
||||
assert gbdt_05.best_ntree_limit == (
|
||||
gbdt_05.best_iteration + 1) * self.num_parallel_tree
|
||||
gbdt_05.best_iteration + 1) * self.num_parallel_tree * 5
|
||||
|
||||
res1 = gbdt_05.predict(dtrain_5class)
|
||||
res2 = gbdt_05.predict(dtrain_5class,
|
||||
|
||||
@@ -78,6 +78,34 @@ def test_multiclass_classification():
|
||||
check_pred(preds4, labels, output_margin=False)
|
||||
|
||||
|
||||
def test_best_ntree_limit():
|
||||
from sklearn.datasets import load_iris
|
||||
|
||||
X, y = load_iris(return_X_y=True)
|
||||
|
||||
def train(booster, forest):
|
||||
rounds = 4
|
||||
cls = xgb.XGBClassifier(
|
||||
n_estimators=rounds, num_parallel_tree=forest, booster=booster
|
||||
).fit(
|
||||
X, y, eval_set=[(X, y)], early_stopping_rounds=3
|
||||
)
|
||||
|
||||
if forest:
|
||||
assert cls.best_ntree_limit == rounds * forest * cls.n_classes_
|
||||
else:
|
||||
assert cls.best_ntree_limit == 0
|
||||
|
||||
# best_ntree_limit is used by default, assert that under gblinear it's
|
||||
# automatically ignored due to being 0.
|
||||
cls.predict(X)
|
||||
|
||||
num_parallel_tree = 4
|
||||
train('gbtree', num_parallel_tree)
|
||||
train('dart', num_parallel_tree)
|
||||
train('gblinear', None)
|
||||
|
||||
|
||||
def test_ranking():
|
||||
# generate random data
|
||||
x_train = np.random.rand(1000, 10)
|
||||
|
||||
Reference in New Issue
Block a user