[backport] Fix best_ntree_limit for dart and gblinear. (#6579) (#6587)

* [backport] Fix `best_ntree_limit` for dart and gblinear. (#6579)

* Backport num group test fix.
Jiaming Yuan 2021-01-11 01:46:05 +08:00 committed by GitHub
parent 7aec915dcd
commit d0ec65520a
3 changed files with 53 additions and 3 deletions
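
For context, `best_ntree_limit` is the attribute users typically pass as `ntree_limit` to `predict` after training with early stopping. A minimal sketch of that workflow, assuming this backported fix is in place (the data and parameter values below are illustrative, not taken from this commit):

import numpy as np
import xgboost as xgb

# Illustrative multiclass data.
X = np.random.rand(200, 5)
y = np.random.randint(0, 3, size=200)
dtrain = xgb.DMatrix(X, label=y)

bst = xgb.train(
    {'objective': 'multi:softprob', 'num_class': 3,
     'booster': 'gbtree', 'num_parallel_tree': 2},
    dtrain,
    num_boost_round=10,
    evals=[(dtrain, 'train')],
    early_stopping_rounds=3,
)

# With this fix, best_ntree_limit is (best_iteration + 1) * num_parallel_tree * num_class
# for tree boosters, and 0 (i.e. "use all iterations") for gblinear.
preds = bst.predict(dtrain, ntree_limit=bst.best_ntree_limit)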


@@ -4,6 +4,7 @@
 """Training Library containing training routines."""
 import warnings
 import copy
+import json
 import numpy as np
 from .core import Booster, XGBoostError
@@ -123,7 +124,28 @@ def _train_internal(params, dtrain,
         bst.best_iteration = int(bst.attr('best_iteration'))
     else:
         bst.best_iteration = nboost - 1
-    bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
+    config = json.loads(bst.save_config())
+    booster = config['learner']['gradient_booster']['name']
+    if booster == 'gblinear':
+        num_parallel_tree = 0
+    elif booster == 'dart':
+        num_parallel_tree = int(
+            config['learner']['gradient_booster']['gbtree']['gbtree_train_param'][
+                'num_parallel_tree'
+            ]
+        )
+    elif booster == 'gbtree':
+        num_parallel_tree = int(
+            config['learner']['gradient_booster']['gbtree_train_param'][
+                'num_parallel_tree']
+        )
+    else:
+        raise ValueError(f'Unknown booster: {booster}')
+    num_groups = int(config['learner']['learner_model_param']['num_class'])
+    num_groups = 1 if num_groups == 0 else num_groups
+    bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree * num_groups
     # Copy to serialise and unserialise booster to reset state and free
     # training memory
     return bst.copy()
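
The new code derives the tree count from the booster's saved JSON configuration rather than from the raw `params` dict. The relevant fields can be inspected directly; a quick sketch, assuming an already-trained gbtree `Booster` named `bst` (the variable name is illustrative):

import json

config = json.loads(bst.save_config())
# Which gradient booster is in use: 'gbtree', 'dart' or 'gblinear'.
print(config['learner']['gradient_booster']['name'])
# Number of classes; '0' for regression/binary objectives, hence the
# `num_groups = 1 if num_groups == 0 else num_groups` normalisation above.
print(config['learner']['learner_model_param']['num_class'])
# Trees grown per boosting round (path shown for gbtree; dart nests it one
# level deeper under 'gbtree', as in the patch above).
print(config['learner']['gradient_booster']['gbtree_train_param']['num_parallel_tree'])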


@@ -123,13 +123,13 @@ class TestTrainingContinuation:
         gbdt_05 = xgb.train(xgb_params_03, dtrain_5class,
                             num_boost_round=7)
         assert gbdt_05.best_ntree_limit == (
-            gbdt_05.best_iteration + 1) * self.num_parallel_tree
+            gbdt_05.best_iteration + 1) * self.num_parallel_tree * 5
         gbdt_05 = xgb.train(xgb_params_03,
                             dtrain_5class,
                             num_boost_round=3,
                             xgb_model=gbdt_05)
         assert gbdt_05.best_ntree_limit == (
-            gbdt_05.best_iteration + 1) * self.num_parallel_tree
+            gbdt_05.best_iteration + 1) * self.num_parallel_tree * 5
         res1 = gbdt_05.predict(dtrain_5class)
         res2 = gbdt_05.predict(dtrain_5class,
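
Each boosting round in this 5-class test grows `num_parallel_tree` trees per class, hence the extra `* 5` factor in the assertions. For example, if `self.num_parallel_tree` were 3 (an illustrative value, not necessarily what the test uses) and all 7 rounds ran without early stopping, `best_ntree_limit` would be (6 + 1) * 3 * 5 = 105.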


@@ -78,6 +78,34 @@ def test_multiclass_classification():
     check_pred(preds4, labels, output_margin=False)
 
 
+def test_best_ntree_limit():
+    from sklearn.datasets import load_iris
+    X, y = load_iris(return_X_y=True)
+
+    def train(booster, forest):
+        rounds = 4
+        cls = xgb.XGBClassifier(
+            n_estimators=rounds, num_parallel_tree=forest, booster=booster
+        ).fit(
+            X, y, eval_set=[(X, y)], early_stopping_rounds=3
+        )
+
+        if forest:
+            assert cls.best_ntree_limit == rounds * forest * cls.n_classes_
+        else:
+            assert cls.best_ntree_limit == 0
+
+        # best_ntree_limit is used by default, assert that under gblinear it's
+        # automatically ignored due to being 0.
+        cls.predict(X)
+
+    num_parallel_tree = 4
+    train('gbtree', num_parallel_tree)
+    train('dart', num_parallel_tree)
+    train('gblinear', None)
+
+
 def test_ranking():
     # generate random data
     x_train = np.random.rand(1000, 10)