Fix best_ntree_limit for dart and gblinear. (#6579)

Jiaming Yuan 2021-01-08 10:05:39 +08:00 committed by GitHub
parent f5ff90cd87
commit 7c9dcbedbc
2 changed files with 47 additions and 7 deletions
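
The gist of the fix: best_ntree_limit is derived as (best_iteration + 1) * num_parallel_tree * num_groups, but num_parallel_tree sits at a different place in the saved config for each booster. The old try/except fell back to num_parallel_tree = 1 both for gblinear (which has no trees at all) and for dart (whose gbtree_train_param is nested one level deeper, under gradient_booster.gbtree, so the lookup raised KeyError). Below is a minimal standalone sketch of the new lookup; the loop, the iris data, and the parameter values are illustrative assumptions, not part of the commit.

import json

import xgboost as xgb
from sklearn.datasets import load_iris

X, y = load_iris(return_X_y=True)
dtrain = xgb.DMatrix(X, label=y)

for name in ('gbtree', 'dart', 'gblinear'):
    params = {'booster': name, 'objective': 'multi:softprob', 'num_class': 3}
    if name != 'gblinear':
        params['num_parallel_tree'] = 4  # forests only make sense for trees
    bst = xgb.train(params, dtrain, num_boost_round=2)
    config = json.loads(bst.save_config())
    booster = config['learner']['gradient_booster']['name']
    if booster == 'gblinear':
        num_parallel_tree = 0  # no trees, so any tree limit must be a no-op
    elif booster == 'dart':
        # dart wraps a gbtree, so its train param sits one level deeper
        num_parallel_tree = int(
            config['learner']['gradient_booster']['gbtree'][
                'gbtree_train_param']['num_parallel_tree'])
    else:  # gbtree
        num_parallel_tree = int(
            config['learner']['gradient_booster']['gbtree_train_param'][
                'num_parallel_tree'])
    print(name, '->', num_parallel_tree)  # gbtree/dart: 4, gblinear: 0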

python-package/xgboost/training.py

@@ -93,16 +93,28 @@ def _train_internal(params, dtrain,
         bst.best_iteration = int(bst.attr('best_iteration'))
     else:
         bst.best_iteration = bst.num_boosted_rounds() - 1
-    config = json.loads(bst.save_config())
-    try:
-        num_parallel_tree = int(config['learner'][
-            'gradient_booster']['gbtree_train_param']['num_parallel_tree'])
-    except KeyError:  # gblinear
-        num_parallel_tree = 1
+    config = json.loads(bst.save_config())
+    booster = config['learner']['gradient_booster']['name']
+    if booster == 'gblinear':
+        num_parallel_tree = 0
+    elif booster == 'dart':
+        num_parallel_tree = int(
+            config['learner']['gradient_booster']['gbtree']['gbtree_train_param'][
+                'num_parallel_tree'
+            ]
+        )
+    elif booster == 'gbtree':
+        num_parallel_tree = int(
+            config['learner']['gradient_booster']['gbtree_train_param'][
+                'num_parallel_tree']
+        )
+    else:
+        raise ValueError(f'Unknown booster: {booster}')
     num_groups = int(config['learner']['learner_model_param']['num_class'])
     num_groups = 1 if num_groups == 0 else num_groups
-    bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree * num_groups
+    bst.best_ntree_limit = ((bst.best_iteration + 1) * num_parallel_tree *
+                            num_groups)
 
     # Copy to serialise and unserialise booster to reset state and free
     # training memory
     return bst.copy()
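
Why 0 is the right value for gblinear rather than the old fallback of 1: predict() interprets ntree_limit=0 as "use the entire model", so a best_ntree_limit of 0 degrades gracefully instead of truncating a model that has no trees. A quick check under early stopping, reusing dtrain from the sketch above (the parameter values are arbitrary):

bst = xgb.train({'booster': 'gblinear', 'objective': 'multi:softprob',
                 'num_class': 3},
                dtrain, num_boost_round=4,
                evals=[(dtrain, 'train')], early_stopping_rounds=3)
# Zero regardless of where training stopped, since num_parallel_tree is 0;
# before this fix it came back as (best_iteration + 1) * num_class.
assert bst.best_ntree_limit == 0
bst.predict(dtrain, ntree_limit=bst.best_ntree_limit)  # full model, no error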

tests/python/test_with_sklearn.py

@@ -92,6 +92,34 @@ def test_multiclass_classification():
     assert proba.shape[1] == cls.n_classes_
 
 
+def test_best_ntree_limit():
+    from sklearn.datasets import load_iris
+
+    X, y = load_iris(return_X_y=True)
+
+    def train(booster, forest):
+        rounds = 4
+        cls = xgb.XGBClassifier(
+            n_estimators=rounds, num_parallel_tree=forest, booster=booster
+        ).fit(
+            X, y, eval_set=[(X, y)], early_stopping_rounds=3
+        )
+
+        if forest:
+            assert cls.best_ntree_limit == rounds * forest * cls.n_classes_
+        else:
+            assert cls.best_ntree_limit == 0
+
+        # best_ntree_limit is used by default, assert that under gblinear it's
+        # automatically ignored due to being 0.
+        cls.predict(X)
+
+    num_parallel_tree = 4
+    train('gbtree', num_parallel_tree)
+    train('dart', num_parallel_tree)
+    train('gblinear', None)
+
+
 def test_ranking():
     # generate random data
     x_train = np.random.rand(1000, 10)
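
For concreteness, the arithmetic the tree cases above rely on: iris has 3 classes, and evaluating on the training set means early stopping never fires, so best_iteration stays at rounds - 1 and the limit is 4 rounds * 4 parallel trees * 3 classes = 48. A standalone check mirroring the test (assumes xgboost and scikit-learn are installed):

import xgboost as xgb
from sklearn.datasets import load_iris

X, y = load_iris(return_X_y=True)
cls = xgb.XGBClassifier(n_estimators=4, num_parallel_tree=4,
                        booster='gbtree').fit(
    X, y, eval_set=[(X, y)], early_stopping_rounds=3)
assert cls.n_classes_ == 3
assert cls.best_ntree_limit == 4 * 4 * 3  # rounds * forest * classes == 48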