Fix best_ntree_limit. (#6569)

This commit is contained in:
Jiaming Yuan 2021-01-03 05:58:54 +08:00 committed by GitHub
parent 195a41cef1
commit 516a93d25c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 4 deletions

View File

@@ -93,12 +93,16 @@ def _train_internal(params, dtrain,
bst.best_iteration = int(bst.attr('best_iteration'))
else:
bst.best_iteration = bst.num_boosted_rounds() - 1
config = json.loads(bst.save_config())
try:
num_parallel_tree = int(json.loads(bst.save_config())['learner'][
num_parallel_tree = int(config['learner'][
'gradient_booster']['gbtree_train_param']['num_parallel_tree'])
except KeyError: # gblinear
num_parallel_tree = 1
bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
num_groups = int(config['learner']['learner_model_param']['num_class'])
num_groups = 1 if num_groups == 0 else num_groups
bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree * num_groups
# Copy to serialise and unserialise booster to reset state and free
# training memory
return bst.copy()

View File

@@ -119,13 +119,13 @@ class TestTrainingContinuation:
gbdt_05 = xgb.train(xgb_params_03, dtrain_5class,
num_boost_round=7)
assert gbdt_05.best_ntree_limit == (
gbdt_05.best_iteration + 1) * self.num_parallel_tree
gbdt_05.best_iteration + 1) * self.num_parallel_tree * 5
gbdt_05 = xgb.train(xgb_params_03,
dtrain_5class,
num_boost_round=3,
xgb_model=gbdt_05)
assert gbdt_05.best_ntree_limit == (
gbdt_05.best_iteration + 1) * self.num_parallel_tree
gbdt_05.best_iteration + 1) * self.num_parallel_tree * 5
res1 = gbdt_05.predict(dtrain_5class)
res2 = gbdt_05.predict(dtrain_5class,