Revert ntree limit fix (#6616)
The old (pre-fix) best_ntree_limit ignored the num_class parameter, which is incorrect. We previously worked around this in the C++ layer to avoid potential breaking changes in the other language bindings, but the Python interpretation remained incorrect. The fix made Python account for num_class without removing the old workaround, so the tree calculation in the predictor is now wrong; see PredictBatch in CPUPredictor.
parent d132933550
commit d6d72de339
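For context, the arithmetic behind the revert, as a minimal sketch (the values below are illustrative, not taken from the codebase): each boosting round grows num_parallel_tree trees per class, so a multi-class forest stores (best_iteration + 1) * num_parallel_tree * num_class trees in total, while PredictBatch in CPUPredictor already multiplies the supplied limit by the number of groups on its own.

# Minimal sketch with illustrative values; plain variables, not xgboost API calls.
num_class = 3          # e.g. iris
num_parallel_tree = 4  # boosted random forest
best_iteration = 9     # early stopping kept rounds 0..9

# Trees actually stored in the model:
total_trees = (best_iteration + 1) * num_parallel_tree * num_class  # 120

# What this revert makes Python report again; the C++ predictor
# (PredictBatch in CPUPredictor) multiplies by the number of groups itself:
best_ntree_limit = (best_iteration + 1) * num_parallel_tree  # 40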
@@ -109,22 +109,19 @@ def _train_internal(params, dtrain,
     else:
         raise ValueError(f'Unknown booster: {booster}')

-    num_groups = int(config['learner']['learner_model_param']['num_class'])
-    num_groups = 1 if num_groups == 0 else num_groups
     if bst.attr('best_score') is not None:
         bst.best_score = float(bst.attr('best_score'))
         bst.best_iteration = int(bst.attr('best_iteration'))
+        # num_class is handled internally
         bst.set_attr(
-            best_ntree_limit=str(
-                (bst.best_iteration + 1) * num_parallel_tree * num_groups
-            )
+            best_ntree_limit=str((bst.best_iteration + 1) * num_parallel_tree)
         )
         bst.best_ntree_limit = int(bst.attr("best_ntree_limit"))
     else:
         # Due to compatibility with version older than 1.4, these attributes are added
         # to Python object even if early stopping is not used.
         bst.best_iteration = bst.num_boosted_rounds() - 1
-        bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree * num_groups
+        bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree

     # Copy to serialise and unserialise booster to reset state and free
     # training memory
@@ -165,9 +162,10 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
         If there's more than one metric in the **eval_metric** parameter given in
         **params**, the last metric will be used for early stopping.
         If early stopping occurs, the model will have three additional fields:
-        ``bst.best_score``, ``bst.best_iteration`` and ``bst.best_ntree_limit``. (Use
+        ``bst.best_score``, ``bst.best_iteration`` and ``bst.best_ntree_limit``. Use
         ``bst.best_ntree_limit`` to get the correct value if ``num_parallel_tree`` and/or
-        ``num_class`` appears in the parameters)
+        ``num_class`` appears in the parameters. ``best_ntree_limit`` is the result of
+        ``num_parallel_tree * best_iteration``.
     evals_result: dict
         This dictionary stores the evaluation results of all the items in watchlist.

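To make the documented contract concrete, a hypothetical usage sketch (the synthetic data and parameter values are invented; only train, DMatrix, predict, and the best_ntree_limit/ntree_limit pairing come from the API documented above):

import numpy as np
import xgboost as xgb

# Synthetic 3-class data, purely for illustration.
X = np.random.rand(100, 10)
y = np.random.randint(0, 3, size=100)
dtrain = xgb.DMatrix(X, label=y)

bst = xgb.train(
    {"objective": "multi:softprob", "num_class": 3, "num_parallel_tree": 4},
    dtrain,
    num_boost_round=10,
    evals=[(dtrain, "train")],
    early_stopping_rounds=2,
)

# best_ntree_limit no longer includes the num_class factor; passing it to
# predict() still covers every class, because the native predictor expands
# the limit over class groups internally.
preds = bst.predict(dtrain, ntree_limit=bst.best_ntree_limit)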
@@ -347,7 +347,7 @@ class TestModels:
         X, y = load_iris(return_X_y=True)
         cls = xgb.XGBClassifier(n_estimators=2)
         cls.fit(X, y, early_stopping_rounds=1, eval_set=[(X, y)])
-        assert cls.get_booster().best_ntree_limit == 2 * cls.n_classes_
+        assert cls.get_booster().best_ntree_limit == 2
         assert cls.best_ntree_limit == cls.get_booster().best_ntree_limit

         with tempfile.TemporaryDirectory() as tmpdir:
@@ -356,7 +356,7 @@ class TestModels:

             cls = xgb.XGBClassifier(n_estimators=2)
             cls.load_model(path)
-            assert cls.get_booster().best_ntree_limit == 2 * cls.n_classes_
+            assert cls.get_booster().best_ntree_limit == 2
             assert cls.best_ntree_limit == cls.get_booster().best_ntree_limit

     @pytest.mark.skipif(**tm.no_sklearn())
@@ -33,9 +33,15 @@ def run_predict_leaf(predictor):
     y = rng.randint(low=0, high=classes, size=rows)
     m = xgb.DMatrix(X, y)
     booster = xgb.train(
-        {'num_parallel_tree': num_parallel_tree, 'num_class': classes,
-         'predictor': predictor, 'tree_method': 'hist'}, m,
-        num_boost_round=num_boost_round)
+        {
+            "num_parallel_tree": num_parallel_tree,
+            "num_class": classes,
+            "predictor": predictor,
+            "tree_method": "hist",
+        },
+        m,
+        num_boost_round=num_boost_round,
+    )

     empty = xgb.DMatrix(np.ones(shape=(0, cols)))
     empty_leaf = booster.predict(empty, pred_leaf=True)
@@ -52,12 +58,19 @@ def run_predict_leaf(predictor):
             end = classes * num_parallel_tree * (j + 1)
             layer = row[start: end]
             for c in range(classes):
-                tree_group = layer[c * num_parallel_tree:
-                                   (c+1) * num_parallel_tree]
+                tree_group = layer[c * num_parallel_tree: (c + 1) * num_parallel_tree]
                 assert tree_group.shape[0] == num_parallel_tree
                 # no subsampling so tree in same forest should output same
                 # leaf.
                 assert np.all(tree_group == tree_group[0])

+    ntree_limit = 2
+    sliced = booster.predict(
+        m, pred_leaf=True, ntree_limit=num_parallel_tree * ntree_limit
+    )
+    first = sliced[0, ...]
+
+    assert first.shape[0] == classes * num_parallel_tree * ntree_limit
     return leaf

@@ -119,13 +119,13 @@ class TestTrainingContinuation:
         gbdt_05 = xgb.train(xgb_params_03, dtrain_5class,
                             num_boost_round=7)
         assert gbdt_05.best_ntree_limit == (
-            gbdt_05.best_iteration + 1) * self.num_parallel_tree * 5
+            gbdt_05.best_iteration + 1) * self.num_parallel_tree
         gbdt_05 = xgb.train(xgb_params_03,
                             dtrain_5class,
                             num_boost_round=3,
                             xgb_model=gbdt_05)
         assert gbdt_05.best_ntree_limit == (
-            gbdt_05.best_iteration + 1) * self.num_parallel_tree * 5
+            gbdt_05.best_iteration + 1) * self.num_parallel_tree

         res1 = gbdt_05.predict(dtrain_5class)
         res2 = gbdt_05.predict(dtrain_5class,
@@ -933,9 +933,9 @@ class TestWithDask:
     def test_feature_weights(self, client: "Client") -> None:
         kRows = 1024
         kCols = 64
-        X = da.random.random((kRows, kCols), chunks=(32, -1))
-        y = da.random.random(kRows, chunks=32)
+        rng = da.random.RandomState(1994)
+        X = rng.random_sample((kRows, kCols), chunks=(32, -1))
+        y = rng.random_sample(kRows, chunks=32)

         fw = np.ones(shape=(kCols,))
         for i in range(kCols):
@@ -106,7 +106,7 @@ def test_best_ntree_limit():
     )

     if forest:
-        assert cls.best_ntree_limit == rounds * forest * cls.n_classes_
+        assert cls.best_ntree_limit == rounds * forest
     else:
         assert cls.best_ntree_limit == 0
