Additional tests for attributes and model booosted rounds. (#9962)

This commit is contained in:
Jiaming Yuan 2024-01-09 09:54:39 +08:00 committed by GitHub
parent bed0349954
commit 2f57bbde3c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 90 additions and 30 deletions

View File

@ -535,8 +535,7 @@ class LearnerConfiguration : public Learner {
tparam_.booster = get<String>(gradient_booster["name"]); tparam_.booster = get<String>(gradient_booster["name"]);
if (!gbm_) { if (!gbm_) {
gbm_.reset(GradientBooster::Create(tparam_.booster, gbm_.reset(GradientBooster::Create(tparam_.booster, &ctx_, &learner_model_param_));
&ctx_, &learner_model_param_));
} }
gbm_->LoadConfig(gradient_booster); gbm_->LoadConfig(gradient_booster);
@ -1095,6 +1094,11 @@ class LearnerIO : public LearnerConfiguration {
std::vector<std::pair<std::string, std::string> > extra_attr; std::vector<std::pair<std::string, std::string> > extra_attr;
mparam.contain_extra_attrs = 1; mparam.contain_extra_attrs = 1;
if (!this->feature_names_.empty() || !this->feature_types_.empty()) {
LOG(WARNING) << "feature names and feature types are being disregarded, use JSON/UBJSON "
"format instead.";
}
{ {
// Similar to JSON model IO, we save the objective. // Similar to JSON model IO, we save the objective.
Json j_obj { Object() }; Json j_obj { Object() };

View File

@ -1,5 +1,4 @@
import json import json
import locale
import os import os
import tempfile import tempfile
@ -110,20 +109,39 @@ class TestModels:
predt_2 = bst.predict(dtrain) predt_2 = bst.predict(dtrain)
assert np.all(np.abs(predt_2 - predt_1) < 1e-6) assert np.all(np.abs(predt_2 - predt_1) < 1e-6)
def test_boost_from_existing_model(self): def test_boost_from_existing_model(self) -> None:
X, _ = tm.load_agaricus(__file__) X, _ = tm.load_agaricus(__file__)
booster = xgb.train({'tree_method': 'hist'}, X, num_boost_round=4) booster = xgb.train({"tree_method": "hist"}, X, num_boost_round=4)
assert booster.num_boosted_rounds() == 4 assert booster.num_boosted_rounds() == 4
booster = xgb.train({'tree_method': 'hist'}, X, num_boost_round=4, booster.set_param({"tree_method": "approx"})
xgb_model=booster) assert booster.num_boosted_rounds() == 4
booster = xgb.train(
{"tree_method": "hist"}, X, num_boost_round=4, xgb_model=booster
)
assert booster.num_boosted_rounds() == 8 assert booster.num_boosted_rounds() == 8
booster = xgb.train({'updater': 'prune', 'process_type': 'update'}, X, with pytest.warns(UserWarning, match="`updater`"):
num_boost_round=4, xgb_model=booster) booster = xgb.train(
{"updater": "prune", "process_type": "update"},
X,
num_boost_round=4,
xgb_model=booster,
)
# Trees are moved for update, the rounds is reduced. This test is # Trees are moved for update, the rounds is reduced. This test is
# written for being compatible with current code (1.0.0). If the # written for being compatible with current code (1.0.0). If the
# behaviour is considered sub-optimal, feel free to change. # behaviour is considered sub-optimal, feel free to change.
assert booster.num_boosted_rounds() == 4 assert booster.num_boosted_rounds() == 4
booster = xgb.train({"booster": "gblinear"}, X, num_boost_round=4)
assert booster.num_boosted_rounds() == 4
booster.set_param({"updater": "coord_descent"})
assert booster.num_boosted_rounds() == 4
booster.set_param({"updater": "shotgun"})
assert booster.num_boosted_rounds() == 4
booster = xgb.train(
{"booster": "gblinear"}, X, num_boost_round=4, xgb_model=booster
)
assert booster.num_boosted_rounds() == 8
def run_custom_objective(self, tree_method=None): def run_custom_objective(self, tree_method=None):
param = { param = {
'max_depth': 2, 'max_depth': 2,
@ -307,25 +325,6 @@ class TestModels:
for d in text_dump: for d in text_dump:
assert d.find(r"feature \"2\"") != -1 assert d.find(r"feature \"2\"") != -1
@pytest.mark.skipif(**tm.no_sklearn())
def test_attributes(self):
from sklearn.datasets import load_iris
X, y = load_iris(return_X_y=True)
cls = xgb.XGBClassifier(n_estimators=2)
cls.fit(X, y, early_stopping_rounds=1, eval_set=[(X, y)])
assert cls.get_booster().best_iteration == cls.n_estimators - 1
assert cls.best_iteration == cls.get_booster().best_iteration
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, "cls.json")
cls.save_model(path)
cls = xgb.XGBClassifier(n_estimators=2)
cls.load_model(path)
assert cls.get_booster().best_iteration == cls.n_estimators - 1
assert cls.best_iteration == cls.get_booster().best_iteration
def run_slice( def run_slice(
self, self,
booster: xgb.Booster, booster: xgb.Booster,
@ -493,18 +492,23 @@ class TestModels:
np.testing.assert_allclose(predt0, predt1, atol=1e-5) np.testing.assert_allclose(predt0, predt1, atol=1e-5)
@pytest.mark.skipif(**tm.no_pandas()) @pytest.mark.skipif(**tm.no_pandas())
def test_feature_info(self): @pytest.mark.parametrize("ext", ["json", "ubj"])
def test_feature_info(self, ext: str) -> None:
import pandas as pd import pandas as pd
# make data
rows = 100 rows = 100
cols = 10 cols = 10
X = rng.randn(rows, cols) X = rng.randn(rows, cols)
y = rng.randn(rows) y = rng.randn(rows)
# Test with pandas, which has feature info.
feature_names = ["test_feature_" + str(i) for i in range(cols)] feature_names = ["test_feature_" + str(i) for i in range(cols)]
X_pd = pd.DataFrame(X, columns=feature_names) X_pd = pd.DataFrame(X, columns=feature_names)
X_pd[f"test_feature_{3}"] = X_pd.iloc[:, 3].astype(np.int32) X_pd[f"test_feature_{3}"] = X_pd.iloc[:, 3].astype(np.int32)
Xy = xgb.DMatrix(X_pd, y) Xy = xgb.DMatrix(X_pd, y)
assert Xy.feature_types is not None
assert Xy.feature_types[3] == "int" assert Xy.feature_types[3] == "int"
booster = xgb.train({}, dtrain=Xy, num_boost_round=1) booster = xgb.train({}, dtrain=Xy, num_boost_round=1)
@ -513,10 +517,32 @@ class TestModels:
assert booster.feature_types == Xy.feature_types assert booster.feature_types == Xy.feature_types
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
path = tmpdir + "model.json" path = tmpdir + f"model.{ext}"
booster.save_model(path) booster.save_model(path)
booster = xgb.Booster() booster = xgb.Booster()
booster.load_model(path) booster.load_model(path)
assert booster.feature_names == Xy.feature_names assert booster.feature_names == Xy.feature_names
assert booster.feature_types == Xy.feature_types assert booster.feature_types == Xy.feature_types
# Test with numpy, no feature info is set
Xy = xgb.DMatrix(X, y)
assert Xy.feature_names is None
assert Xy.feature_types is None
booster = xgb.train({}, dtrain=Xy, num_boost_round=1)
assert booster.feature_names is None
assert booster.feature_types is None
# test explicitly set
fns = [str(i) for i in range(cols)]
booster.feature_names = fns
assert booster.feature_names == fns
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, f"model.{ext}")
booster.save_model(path)
booster = xgb.Booster(model_file=path)
assert booster.feature_names == fns

View File

@ -466,3 +466,33 @@ def test_with_sklearn_obj_metric() -> None:
assert not callable(reg_2.objective) assert not callable(reg_2.objective)
assert not callable(reg_2.eval_metric) assert not callable(reg_2.eval_metric)
assert reg_2.eval_metric is None assert reg_2.eval_metric is None
@pytest.mark.skipif(**tm.no_sklearn())
def test_attributes() -> None:
from sklearn.datasets import load_iris
X, y = load_iris(return_X_y=True)
clf = xgb.XGBClassifier(n_estimators=2, early_stopping_rounds=1)
clf.fit(X, y, eval_set=[(X, y)])
best_iteration = clf.get_booster().best_iteration
assert best_iteration is not None
assert clf.n_estimators is not None
assert best_iteration == clf.n_estimators - 1
best_iteration = clf.best_iteration
assert best_iteration == clf.get_booster().best_iteration
clf.get_booster().set_attr(foo="bar")
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, "clf.json")
clf.save_model(path)
clf = xgb.XGBClassifier(n_estimators=2)
clf.load_model(path)
assert clf.n_estimators is not None
assert clf.get_booster().best_iteration == clf.n_estimators - 1
assert clf.best_iteration == clf.get_booster().best_iteration
assert clf.get_booster().attributes()["foo"] == "bar"