diff --git a/python-package/xgboost/training.py b/python-package/xgboost/training.py
index 007e4b186..d80c2a6fa 100644
--- a/python-package/xgboost/training.py
+++ b/python-package/xgboost/training.py
@@ -93,16 +93,28 @@ def _train_internal(params, dtrain,
         bst.best_iteration = int(bst.attr('best_iteration'))
     else:
         bst.best_iteration = bst.num_boosted_rounds() - 1
-    config = json.loads(bst.save_config())
-    try:
-        num_parallel_tree = int(config['learner'][
-            'gradient_booster']['gbtree_train_param']['num_parallel_tree'])
-    except KeyError:  # gblinear
-        num_parallel_tree = 1
+    config = json.loads(bst.save_config())
+    booster = config['learner']['gradient_booster']['name']
+    if booster == 'gblinear':
+        num_parallel_tree = 0
+    elif booster == 'dart':
+        num_parallel_tree = int(
+            config['learner']['gradient_booster']['gbtree']['gbtree_train_param'][
+                'num_parallel_tree'
+            ]
+        )
+    elif booster == 'gbtree':
+        num_parallel_tree = int(
+            config['learner']['gradient_booster']['gbtree_train_param'][
+                'num_parallel_tree']
+        )
+    else:
+        raise ValueError(f'Unknown booster: {booster}')
     num_groups = int(config['learner']['learner_model_param']['num_class'])
     num_groups = 1 if num_groups == 0 else num_groups
-    bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree * num_groups
+    bst.best_ntree_limit = ((bst.best_iteration + 1) * num_parallel_tree * num_groups)
+
 
     # Copy to serialise and unserialise booster to reset state and free
     # training memory
     return bst.copy()
diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py
index 99a1a5702..ffdfd7df5 100644
--- a/tests/python/test_with_sklearn.py
+++ b/tests/python/test_with_sklearn.py
@@ -92,6 +92,34 @@ def test_multiclass_classification():
     assert proba.shape[1] == cls.n_classes_
 
 
+def test_best_ntree_limit():
+    from sklearn.datasets import load_iris
+
+    X, y = load_iris(return_X_y=True)
+
+    def train(booster, forest):
+        rounds = 4
+        cls = xgb.XGBClassifier(
+            n_estimators=rounds, num_parallel_tree=forest, booster=booster
+        ).fit(
+            X, y, eval_set=[(X, y)], early_stopping_rounds=3
+        )
+
+        if forest:
+            assert cls.best_ntree_limit == rounds * forest * cls.n_classes_
+        else:
+            assert cls.best_ntree_limit == 0
+
+        # best_ntree_limit is used by default, assert that under gblinear it's
+        # automatically ignored due to being 0.
+        cls.predict(X)
+
+    num_parallel_tree = 4
+    train('gbtree', num_parallel_tree)
+    train('dart', num_parallel_tree)
+    train('gblinear', None)
+
+
 def test_ranking():
     # generate random data
     x_train = np.random.rand(1000, 10)
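
For reviewers who want to poke at the config layout the new branch logic depends on, here is a minimal sketch (not part of the patch) that prints the booster name `_train_internal` now switches on. The JSON paths are the ones read in the hunk above; the data and training parameters are made up for illustration.

```python
import json

import numpy as np
import xgboost as xgb

X = np.random.rand(64, 4)
y = np.random.randint(0, 2, size=64)
dtrain = xgb.DMatrix(X, label=y)

for booster in ('gbtree', 'dart', 'gblinear'):
    bst = xgb.train({'booster': booster}, dtrain, num_boost_round=2)
    config = json.loads(bst.save_config())
    # The patch dispatches on this name instead of catching KeyError, which
    # is what lets it distinguish dart (whose train param is nested one level
    # deeper, under 'gbtree') from plain gbtree, and gblinear from both.
    print(booster, config['learner']['gradient_booster']['name'])
```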
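And a hedged sketch of how the fixed value is consumed with the native API: `best_ntree_limit` counts individual trees, not boosting rounds, which is why `num_parallel_tree` and the number of class groups enter the product; it is conventionally passed to `predict` as `ntree_limit`. The dataset and parameter choices below are illustrative only.

```python
import numpy as np
import xgboost as xgb

X = np.random.rand(128, 4)
y = np.random.randint(0, 2, size=128)
dtrain = xgb.DMatrix(X, label=y)

bst = xgb.train(
    {'booster': 'gbtree', 'num_parallel_tree': 4},
    dtrain,
    num_boost_round=8,
    evals=[(dtrain, 'train')],
    early_stopping_rounds=3,
)
# Each boosting round grows num_parallel_tree trees per class group, so the
# limit counts individual trees: (best_iteration + 1) * 4 * 1 in this setup.
preds = bst.predict(dtrain, ntree_limit=bst.best_ntree_limit)
```

With gblinear there are no trees at all, which is why the patch sets the limit to 0 (the "use everything" sentinel for `ntree_limit`) rather than pretending there is one parallel tree, as the old `except KeyError` fallback did.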