* [backport] Fix `best_ntree_limit` for dart and gblinear. (#6579) * Backport num group test fix.
This commit is contained in:
parent
7aec915dcd
commit
d0ec65520a
@ -4,6 +4,7 @@
|
|||||||
"""Training Library containing training routines."""
|
"""Training Library containing training routines."""
|
||||||
import warnings
|
import warnings
|
||||||
import copy
|
import copy
|
||||||
|
import json
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from .core import Booster, XGBoostError
|
from .core import Booster, XGBoostError
|
||||||
@ -123,7 +124,28 @@ def _train_internal(params, dtrain,
|
|||||||
bst.best_iteration = int(bst.attr('best_iteration'))
|
bst.best_iteration = int(bst.attr('best_iteration'))
|
||||||
else:
|
else:
|
||||||
bst.best_iteration = nboost - 1
|
bst.best_iteration = nboost - 1
|
||||||
bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
|
|
||||||
|
config = json.loads(bst.save_config())
|
||||||
|
booster = config['learner']['gradient_booster']['name']
|
||||||
|
if booster == 'gblinear':
|
||||||
|
num_parallel_tree = 0
|
||||||
|
elif booster == 'dart':
|
||||||
|
num_parallel_tree = int(
|
||||||
|
config['learner']['gradient_booster']['gbtree']['gbtree_train_param'][
|
||||||
|
'num_parallel_tree'
|
||||||
|
]
|
||||||
|
)
|
||||||
|
elif booster == 'gbtree':
|
||||||
|
num_parallel_tree = int(
|
||||||
|
config['learner']['gradient_booster']['gbtree_train_param'][
|
||||||
|
'num_parallel_tree']
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Unknown booster: {booster}')
|
||||||
|
num_groups = int(config['learner']['learner_model_param']['num_class'])
|
||||||
|
num_groups = 1 if num_groups == 0 else num_groups
|
||||||
|
bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree * num_groups
|
||||||
|
|
||||||
# Copy to serialise and unserialise booster to reset state and free
|
# Copy to serialise and unserialise booster to reset state and free
|
||||||
# training memory
|
# training memory
|
||||||
return bst.copy()
|
return bst.copy()
|
||||||
|
|||||||
@ -123,13 +123,13 @@ class TestTrainingContinuation:
|
|||||||
gbdt_05 = xgb.train(xgb_params_03, dtrain_5class,
|
gbdt_05 = xgb.train(xgb_params_03, dtrain_5class,
|
||||||
num_boost_round=7)
|
num_boost_round=7)
|
||||||
assert gbdt_05.best_ntree_limit == (
|
assert gbdt_05.best_ntree_limit == (
|
||||||
gbdt_05.best_iteration + 1) * self.num_parallel_tree
|
gbdt_05.best_iteration + 1) * self.num_parallel_tree * 5
|
||||||
gbdt_05 = xgb.train(xgb_params_03,
|
gbdt_05 = xgb.train(xgb_params_03,
|
||||||
dtrain_5class,
|
dtrain_5class,
|
||||||
num_boost_round=3,
|
num_boost_round=3,
|
||||||
xgb_model=gbdt_05)
|
xgb_model=gbdt_05)
|
||||||
assert gbdt_05.best_ntree_limit == (
|
assert gbdt_05.best_ntree_limit == (
|
||||||
gbdt_05.best_iteration + 1) * self.num_parallel_tree
|
gbdt_05.best_iteration + 1) * self.num_parallel_tree * 5
|
||||||
|
|
||||||
res1 = gbdt_05.predict(dtrain_5class)
|
res1 = gbdt_05.predict(dtrain_5class)
|
||||||
res2 = gbdt_05.predict(dtrain_5class,
|
res2 = gbdt_05.predict(dtrain_5class,
|
||||||
|
|||||||
@ -78,6 +78,34 @@ def test_multiclass_classification():
|
|||||||
check_pred(preds4, labels, output_margin=False)
|
check_pred(preds4, labels, output_margin=False)
|
||||||
|
|
||||||
|
|
||||||
|
def test_best_ntree_limit():
|
||||||
|
from sklearn.datasets import load_iris
|
||||||
|
|
||||||
|
X, y = load_iris(return_X_y=True)
|
||||||
|
|
||||||
|
def train(booster, forest):
|
||||||
|
rounds = 4
|
||||||
|
cls = xgb.XGBClassifier(
|
||||||
|
n_estimators=rounds, num_parallel_tree=forest, booster=booster
|
||||||
|
).fit(
|
||||||
|
X, y, eval_set=[(X, y)], early_stopping_rounds=3
|
||||||
|
)
|
||||||
|
|
||||||
|
if forest:
|
||||||
|
assert cls.best_ntree_limit == rounds * forest * cls.n_classes_
|
||||||
|
else:
|
||||||
|
assert cls.best_ntree_limit == 0
|
||||||
|
|
||||||
|
# best_ntree_limit is used by default, assert that under gblinear it's
|
||||||
|
# automatically ignored due to being 0.
|
||||||
|
cls.predict(X)
|
||||||
|
|
||||||
|
num_parallel_tree = 4
|
||||||
|
train('gbtree', num_parallel_tree)
|
||||||
|
train('dart', num_parallel_tree)
|
||||||
|
train('gblinear', None)
|
||||||
|
|
||||||
|
|
||||||
def test_ranking():
|
def test_ranking():
|
||||||
# generate random data
|
# generate random data
|
||||||
x_train = np.random.rand(1000, 10)
|
x_train = np.random.rand(1000, 10)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user