From 516a93d25c3b6899558700430ffc99a29ea21e1a Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sun, 3 Jan 2021 05:58:54 +0800 Subject: [PATCH] Fix `best_ntree_limit`. (#6569) --- python-package/xgboost/training.py | 8 ++++++-- tests/python/test_training_continuation.py | 4 ++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/python-package/xgboost/training.py b/python-package/xgboost/training.py index 1467fe726..007e4b186 100644 --- a/python-package/xgboost/training.py +++ b/python-package/xgboost/training.py @@ -93,12 +93,16 @@ def _train_internal(params, dtrain, bst.best_iteration = int(bst.attr('best_iteration')) else: bst.best_iteration = bst.num_boosted_rounds() - 1 + config = json.loads(bst.save_config()) try: - num_parallel_tree = int(json.loads(bst.save_config())['learner'][ + num_parallel_tree = int(config['learner'][ 'gradient_booster']['gbtree_train_param']['num_parallel_tree']) except KeyError: # gblinear num_parallel_tree = 1 - bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree + + num_groups = int(config['learner']['learner_model_param']['num_class']) + num_groups = 1 if num_groups == 0 else num_groups + bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree * num_groups # Copy to serialise and unserialise booster to reset state and free # training memory return bst.copy() diff --git a/tests/python/test_training_continuation.py b/tests/python/test_training_continuation.py index 762efaf95..e56fc9b2d 100644 --- a/tests/python/test_training_continuation.py +++ b/tests/python/test_training_continuation.py @@ -119,13 +119,13 @@ class TestTrainingContinuation: gbdt_05 = xgb.train(xgb_params_03, dtrain_5class, num_boost_round=7) assert gbdt_05.best_ntree_limit == ( - gbdt_05.best_iteration + 1) * self.num_parallel_tree + gbdt_05.best_iteration + 1) * self.num_parallel_tree * 5 gbdt_05 = xgb.train(xgb_params_03, dtrain_5class, num_boost_round=3, xgb_model=gbdt_05) assert gbdt_05.best_ntree_limit == ( - gbdt_05.best_iteration + 1) * self.num_parallel_tree + gbdt_05.best_iteration + 1) * self.num_parallel_tree * 5 res1 = gbdt_05.predict(dtrain_5class) res2 = gbdt_05.predict(dtrain_5class,