[backport] Fix best_ntree_limit for dart and gblinear. (#6579) (#6587)

* [backport] Fix `best_ntree_limit` for dart and gblinear. (#6579)

* Backport num group test fix.
This commit is contained in:
Jiaming Yuan
2021-01-11 01:46:05 +08:00
committed by GitHub
parent 7aec915dcd
commit d0ec65520a
3 changed files with 53 additions and 3 deletions

View File

@@ -4,6 +4,7 @@
"""Training Library containing training routines."""
import warnings
import copy
import json
import numpy as np
from .core import Booster, XGBoostError
@@ -123,7 +124,28 @@ def _train_internal(params, dtrain,
bst.best_iteration = int(bst.attr('best_iteration'))
else:
bst.best_iteration = nboost - 1
bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
config = json.loads(bst.save_config())
booster = config['learner']['gradient_booster']['name']
if booster == 'gblinear':
num_parallel_tree = 0
elif booster == 'dart':
num_parallel_tree = int(
config['learner']['gradient_booster']['gbtree']['gbtree_train_param'][
'num_parallel_tree'
]
)
elif booster == 'gbtree':
num_parallel_tree = int(
config['learner']['gradient_booster']['gbtree_train_param'][
'num_parallel_tree']
)
else:
raise ValueError(f'Unknown booster: {booster}')
num_groups = int(config['learner']['learner_model_param']['num_class'])
num_groups = 1 if num_groups == 0 else num_groups
bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree * num_groups
# Copy to serialise and unserialise booster to reset state and free
# training memory
return bst.copy()