Fix best_ntree_limit. (#6569)
This commit is contained in:
parent
195a41cef1
commit
516a93d25c
@ -93,12 +93,16 @@ def _train_internal(params, dtrain,
|
|||||||
bst.best_iteration = int(bst.attr('best_iteration'))
|
bst.best_iteration = int(bst.attr('best_iteration'))
|
||||||
else:
|
else:
|
||||||
bst.best_iteration = bst.num_boosted_rounds() - 1
|
bst.best_iteration = bst.num_boosted_rounds() - 1
|
||||||
|
config = json.loads(bst.save_config())
|
||||||
try:
|
try:
|
||||||
num_parallel_tree = int(json.loads(bst.save_config())['learner'][
|
num_parallel_tree = int(config['learner'][
|
||||||
'gradient_booster']['gbtree_train_param']['num_parallel_tree'])
|
'gradient_booster']['gbtree_train_param']['num_parallel_tree'])
|
||||||
except KeyError: # gblinear
|
except KeyError: # gblinear
|
||||||
num_parallel_tree = 1
|
num_parallel_tree = 1
|
||||||
bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
|
|
||||||
|
num_groups = int(config['learner']['learner_model_param']['num_class'])
|
||||||
|
num_groups = 1 if num_groups == 0 else num_groups
|
||||||
|
bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree * num_groups
|
||||||
# Copy to serialise and unserialise booster to reset state and free
|
# Copy to serialise and unserialise booster to reset state and free
|
||||||
# training memory
|
# training memory
|
||||||
return bst.copy()
|
return bst.copy()
|
||||||
|
|||||||
@ -119,13 +119,13 @@ class TestTrainingContinuation:
|
|||||||
gbdt_05 = xgb.train(xgb_params_03, dtrain_5class,
|
gbdt_05 = xgb.train(xgb_params_03, dtrain_5class,
|
||||||
num_boost_round=7)
|
num_boost_round=7)
|
||||||
assert gbdt_05.best_ntree_limit == (
|
assert gbdt_05.best_ntree_limit == (
|
||||||
gbdt_05.best_iteration + 1) * self.num_parallel_tree
|
gbdt_05.best_iteration + 1) * self.num_parallel_tree * 5
|
||||||
gbdt_05 = xgb.train(xgb_params_03,
|
gbdt_05 = xgb.train(xgb_params_03,
|
||||||
dtrain_5class,
|
dtrain_5class,
|
||||||
num_boost_round=3,
|
num_boost_round=3,
|
||||||
xgb_model=gbdt_05)
|
xgb_model=gbdt_05)
|
||||||
assert gbdt_05.best_ntree_limit == (
|
assert gbdt_05.best_ntree_limit == (
|
||||||
gbdt_05.best_iteration + 1) * self.num_parallel_tree
|
gbdt_05.best_iteration + 1) * self.num_parallel_tree * 5
|
||||||
|
|
||||||
res1 = gbdt_05.predict(dtrain_5class)
|
res1 = gbdt_05.predict(dtrain_5class)
|
||||||
res2 = gbdt_05.predict(dtrain_5class,
|
res2 = gbdt_05.predict(dtrain_5class,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user