Fix best_ntree_limit for dart and gblinear. (#6579)
This commit is contained in:
@@ -93,16 +93,28 @@ def _train_internal(params, dtrain,
|
||||
bst.best_iteration = int(bst.attr('best_iteration'))
|
||||
else:
|
||||
bst.best_iteration = bst.num_boosted_rounds() - 1
|
||||
config = json.loads(bst.save_config())
|
||||
try:
|
||||
num_parallel_tree = int(config['learner'][
|
||||
'gradient_booster']['gbtree_train_param']['num_parallel_tree'])
|
||||
except KeyError: # gblinear
|
||||
num_parallel_tree = 1
|
||||
|
||||
config = json.loads(bst.save_config())
|
||||
booster = config['learner']['gradient_booster']['name']
|
||||
if booster == 'gblinear':
|
||||
num_parallel_tree = 0
|
||||
elif booster == 'dart':
|
||||
num_parallel_tree = int(
|
||||
config['learner']['gradient_booster']['gbtree']['gbtree_train_param'][
|
||||
'num_parallel_tree'
|
||||
]
|
||||
)
|
||||
elif booster == 'gbtree':
|
||||
num_parallel_tree = int(
|
||||
config['learner']['gradient_booster']['gbtree_train_param'][
|
||||
'num_parallel_tree']
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'Unknown booster: {booster}')
|
||||
num_groups = int(config['learner']['learner_model_param']['num_class'])
|
||||
num_groups = 1 if num_groups == 0 else num_groups
|
||||
bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree * num_groups
|
||||
bst.best_ntree_limit = ((bst.best_iteration + 1) * num_parallel_tree * num_groups)
|
||||
|
||||
# Copy to serialise and unserialise booster to reset state and free
|
||||
# training memory
|
||||
return bst.copy()
|
||||
|
||||
Reference in New Issue
Block a user