From d0ec65520a008980dff611d9fd0cb53c1389627f Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Mon, 11 Jan 2021 01:46:05 +0800
Subject: [PATCH] [backport] Fix `best_ntree_limit` for dart and gblinear.
 (#6579) (#6587)

* [backport] Fix `best_ntree_limit` for dart and gblinear. (#6579)

* Backport num group test fix.
---
 python-package/xgboost/training.py         | 24 ++++++++++++++++++-
 tests/python/test_training_continuation.py |  4 ++--
 tests/python/test_with_sklearn.py          | 28 ++++++++++++++++++++++
 3 files changed, 53 insertions(+), 3 deletions(-)

diff --git a/python-package/xgboost/training.py b/python-package/xgboost/training.py
index 8db3a9798..34ad027d1 100644
--- a/python-package/xgboost/training.py
+++ b/python-package/xgboost/training.py
@@ -4,6 +4,7 @@
 """Training Library containing training routines."""
 import warnings
 import copy
+import json
 import numpy as np
 from .core import Booster, XGBoostError
@@ -123,7 +124,28 @@ def _train_internal(params, dtrain,
         bst.best_iteration = int(bst.attr('best_iteration'))
     else:
         bst.best_iteration = nboost - 1
-    bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
+
+    config = json.loads(bst.save_config())
+    booster = config['learner']['gradient_booster']['name']
+    if booster == 'gblinear':
+        num_parallel_tree = 0
+    elif booster == 'dart':
+        num_parallel_tree = int(
+            config['learner']['gradient_booster']['gbtree']['gbtree_train_param'][
+                'num_parallel_tree'
+            ]
+        )
+    elif booster == 'gbtree':
+        num_parallel_tree = int(
+            config['learner']['gradient_booster']['gbtree_train_param'][
+                'num_parallel_tree']
+        )
+    else:
+        raise ValueError(f'Unknown booster: {booster}')
+    num_groups = int(config['learner']['learner_model_param']['num_class'])
+    num_groups = 1 if num_groups == 0 else num_groups
+    bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree * num_groups
+
     # Copy to serialise and unserialise booster to reset state and free
     # training memory
     return bst.copy()
diff --git a/tests/python/test_training_continuation.py b/tests/python/test_training_continuation.py
index 9990ca61b..2c4e577d2 100644
--- a/tests/python/test_training_continuation.py
+++ b/tests/python/test_training_continuation.py
@@ -123,13 +123,13 @@ class TestTrainingContinuation:
         gbdt_05 = xgb.train(xgb_params_03, dtrain_5class,
                             num_boost_round=7)
         assert gbdt_05.best_ntree_limit == (
-            gbdt_05.best_iteration + 1) * self.num_parallel_tree
+            gbdt_05.best_iteration + 1) * self.num_parallel_tree * 5
 
         gbdt_05 = xgb.train(xgb_params_03, dtrain_5class,
                             num_boost_round=3, xgb_model=gbdt_05)
         assert gbdt_05.best_ntree_limit == (
-            gbdt_05.best_iteration + 1) * self.num_parallel_tree
+            gbdt_05.best_iteration + 1) * self.num_parallel_tree * 5
 
         res1 = gbdt_05.predict(dtrain_5class)
         res2 = gbdt_05.predict(dtrain_5class,
diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py
index 8a4f17ffb..f1d5f2bef 100644
--- a/tests/python/test_with_sklearn.py
+++ b/tests/python/test_with_sklearn.py
@@ -78,6 +78,34 @@ def test_multiclass_classification():
     check_pred(preds4, labels, output_margin=False)
 
 
+def test_best_ntree_limit():
+    from sklearn.datasets import load_iris
+
+    X, y = load_iris(return_X_y=True)
+
+    def train(booster, forest):
+        rounds = 4
+        cls = xgb.XGBClassifier(
+            n_estimators=rounds, num_parallel_tree=forest, booster=booster
+        ).fit(
+            X, y, eval_set=[(X, y)], early_stopping_rounds=3
+        )
+
+        if forest:
+            assert cls.best_ntree_limit == rounds * forest * cls.n_classes_
+        else:
+            assert cls.best_ntree_limit == 0
+
+        # best_ntree_limit is used by default, assert that under gblinear it's
+        # automatically ignored due to being 0.
+        cls.predict(X)
+
+    num_parallel_tree = 4
+    train('gbtree', num_parallel_tree)
+    train('dart', num_parallel_tree)
+    train('gblinear', None)
+
+
 def test_ranking():
     # generate random data
     x_train = np.random.rand(1000, 10)
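
---

Note for reviewers (not part of the patch): a minimal sketch of the computation
the patched `_train_internal` performs, reading `num_parallel_tree` and
`num_class` back from `Booster.save_config()`. The synthetic data, the gbtree
booster, `num_parallel_tree=2`, and `num_class=3` below are illustrative
assumptions, not values taken from the test suite.

    import json
    import numpy as np
    import xgboost as xgb

    # Small synthetic multi-class problem: 3 classes, 2 parallel trees.
    X = np.random.rand(100, 4)
    y = np.random.randint(0, 3, size=100)
    dtrain = xgb.DMatrix(X, label=y)

    rounds = 4
    bst = xgb.train(
        {'booster': 'gbtree', 'num_parallel_tree': 2, 'num_class': 3,
         'objective': 'multi:softprob'},
        dtrain, num_boost_round=rounds)

    # Mirror the patched computation: read the trained configuration back.
    config = json.loads(bst.save_config())
    num_parallel_tree = int(
        config['learner']['gradient_booster']['gbtree_train_param'][
            'num_parallel_tree'])
    num_groups = int(config['learner']['learner_model_param']['num_class'])
    num_groups = 1 if num_groups == 0 else num_groups

    # Without early stopping, best_iteration is rounds - 1, so the limit
    # counts every tree: 4 rounds * 2 parallel trees * 3 classes = 24.
    assert bst.best_ntree_limit == rounds * num_parallel_tree * num_groups

Deriving the values from the saved config rather than the training parameters is
what lets gblinear report 0 (no trees to limit) and lets dart and multi-class
gbtree account for every tree built per boosting round.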