From 2f57bbde3c5feeb04c26fb2bb1391c63568662cf Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Tue, 9 Jan 2024 09:54:39 +0800
Subject: [PATCH] Additional tests for attributes and model boosted rounds. (#9962)

---
 src/learner.cc                    |  8 ++-
 tests/python/test_basic_models.py | 82 ++++++++++++++++++++-----------
 tests/python/test_model_io.py     | 30 +++++++++++
 3 files changed, 90 insertions(+), 30 deletions(-)

diff --git a/src/learner.cc b/src/learner.cc
index 6b0fd7e4b..db72f7164 100644
--- a/src/learner.cc
+++ b/src/learner.cc
@@ -535,8 +535,7 @@ class LearnerConfiguration : public Learner {
     tparam_.booster = get(gradient_booster["name"]);
     if (!gbm_) {
-      gbm_.reset(GradientBooster::Create(tparam_.booster,
-                                         &ctx_, &learner_model_param_));
+      gbm_.reset(GradientBooster::Create(tparam_.booster, &ctx_, &learner_model_param_));
     }
     gbm_->LoadConfig(gradient_booster);
@@ -1095,6 +1094,11 @@ class LearnerIO : public LearnerConfiguration {
     std::vector<std::pair<std::string, std::string>> extra_attr;
     mparam.contain_extra_attrs = 1;

+    if (!this->feature_names_.empty() || !this->feature_types_.empty()) {
+      LOG(WARNING) << "feature names and feature types are being disregarded, use JSON/UBJSON "
+                      "format instead.";
+    }
+
     {
       // Similar to JSON model IO, we save the objective.
       Json j_obj { Object() };
diff --git a/tests/python/test_basic_models.py b/tests/python/test_basic_models.py
index ca35c4e91..828c24862 100644
--- a/tests/python/test_basic_models.py
+++ b/tests/python/test_basic_models.py
@@ -1,5 +1,4 @@
 import json
-import locale
 import os
 import tempfile

@@ -110,20 +109,39 @@ class TestModels:
         predt_2 = bst.predict(dtrain)
         assert np.all(np.abs(predt_2 - predt_1) < 1e-6)

-    def test_boost_from_existing_model(self):
+    def test_boost_from_existing_model(self) -> None:
         X, _ = tm.load_agaricus(__file__)
-        booster = xgb.train({'tree_method': 'hist'}, X, num_boost_round=4)
+        booster = xgb.train({"tree_method": "hist"}, X, num_boost_round=4)
         assert booster.num_boosted_rounds() == 4
-        booster = xgb.train({'tree_method': 'hist'}, X, num_boost_round=4,
-                            xgb_model=booster)
+        booster.set_param({"tree_method": "approx"})
+        assert booster.num_boosted_rounds() == 4
+        booster = xgb.train(
+            {"tree_method": "hist"}, X, num_boost_round=4, xgb_model=booster
+        )
         assert booster.num_boosted_rounds() == 8
-        booster = xgb.train({'updater': 'prune', 'process_type': 'update'}, X,
-                            num_boost_round=4, xgb_model=booster)
+        with pytest.warns(UserWarning, match="`updater`"):
+            booster = xgb.train(
+                {"updater": "prune", "process_type": "update"},
+                X,
+                num_boost_round=4,
+                xgb_model=booster,
+            )
         # Trees are moved for update, the rounds is reduced. This test is
         # written for being compatible with current code (1.0.0). If the
         # behaviour is considered sub-optimal, feel free to change.
         assert booster.num_boosted_rounds() == 4
+        booster = xgb.train({"booster": "gblinear"}, X, num_boost_round=4)
+        assert booster.num_boosted_rounds() == 4
+        booster.set_param({"updater": "coord_descent"})
+        assert booster.num_boosted_rounds() == 4
+        booster.set_param({"updater": "shotgun"})
+        assert booster.num_boosted_rounds() == 4
+        booster = xgb.train(
+            {"booster": "gblinear"}, X, num_boost_round=4, xgb_model=booster
+        )
+        assert booster.num_boosted_rounds() == 8
+
     def run_custom_objective(self, tree_method=None):
         param = {
             'max_depth': 2,
@@ -307,25 +325,6 @@ class TestModels:
         for d in text_dump:
             assert d.find(r"feature \"2\"") != -1

-    @pytest.mark.skipif(**tm.no_sklearn())
-    def test_attributes(self):
-        from sklearn.datasets import load_iris
-
-        X, y = load_iris(return_X_y=True)
-        cls = xgb.XGBClassifier(n_estimators=2)
-        cls.fit(X, y, early_stopping_rounds=1, eval_set=[(X, y)])
-        assert cls.get_booster().best_iteration == cls.n_estimators - 1
-        assert cls.best_iteration == cls.get_booster().best_iteration
-
-        with tempfile.TemporaryDirectory() as tmpdir:
-            path = os.path.join(tmpdir, "cls.json")
-            cls.save_model(path)
-
-            cls = xgb.XGBClassifier(n_estimators=2)
-            cls.load_model(path)
-            assert cls.get_booster().best_iteration == cls.n_estimators - 1
-            assert cls.best_iteration == cls.get_booster().best_iteration
-
     def run_slice(
         self,
         booster: xgb.Booster,
@@ -493,18 +492,23 @@ class TestModels:
         np.testing.assert_allclose(predt0, predt1, atol=1e-5)

     @pytest.mark.skipif(**tm.no_pandas())
-    def test_feature_info(self):
+    @pytest.mark.parametrize("ext", ["json", "ubj"])
+    def test_feature_info(self, ext: str) -> None:
         import pandas as pd

+        # make data
         rows = 100
         cols = 10
         X = rng.randn(rows, cols)
         y = rng.randn(rows)
+
+        # Test with pandas, which has feature info.
feature_names = ["test_feature_" + str(i) for i in range(cols)] X_pd = pd.DataFrame(X, columns=feature_names) X_pd[f"test_feature_{3}"] = X_pd.iloc[:, 3].astype(np.int32) Xy = xgb.DMatrix(X_pd, y) + assert Xy.feature_types is not None assert Xy.feature_types[3] == "int" booster = xgb.train({}, dtrain=Xy, num_boost_round=1) @@ -513,10 +517,32 @@ class TestModels: assert booster.feature_types == Xy.feature_types with tempfile.TemporaryDirectory() as tmpdir: - path = tmpdir + "model.json" + path = tmpdir + f"model.{ext}" booster.save_model(path) booster = xgb.Booster() booster.load_model(path) assert booster.feature_names == Xy.feature_names assert booster.feature_types == Xy.feature_types + + # Test with numpy, no feature info is set + Xy = xgb.DMatrix(X, y) + assert Xy.feature_names is None + assert Xy.feature_types is None + + booster = xgb.train({}, dtrain=Xy, num_boost_round=1) + assert booster.feature_names is None + assert booster.feature_types is None + + # test explicitly set + fns = [str(i) for i in range(cols)] + booster.feature_names = fns + + assert booster.feature_names == fns + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, f"model.{ext}") + booster.save_model(path) + + booster = xgb.Booster(model_file=path) + assert booster.feature_names == fns diff --git a/tests/python/test_model_io.py b/tests/python/test_model_io.py index 884bba08f..df0fff22e 100644 --- a/tests/python/test_model_io.py +++ b/tests/python/test_model_io.py @@ -466,3 +466,33 @@ def test_with_sklearn_obj_metric() -> None: assert not callable(reg_2.objective) assert not callable(reg_2.eval_metric) assert reg_2.eval_metric is None + + +@pytest.mark.skipif(**tm.no_sklearn()) +def test_attributes() -> None: + from sklearn.datasets import load_iris + + X, y = load_iris(return_X_y=True) + clf = xgb.XGBClassifier(n_estimators=2, early_stopping_rounds=1) + clf.fit(X, y, eval_set=[(X, y)]) + best_iteration = clf.get_booster().best_iteration + assert best_iteration is not None + assert clf.n_estimators is not None + assert best_iteration == clf.n_estimators - 1 + + best_iteration = clf.best_iteration + assert best_iteration == clf.get_booster().best_iteration + + clf.get_booster().set_attr(foo="bar") + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "clf.json") + clf.save_model(path) + + clf = xgb.XGBClassifier(n_estimators=2) + clf.load_model(path) + assert clf.n_estimators is not None + assert clf.get_booster().best_iteration == clf.n_estimators - 1 + assert clf.best_iteration == clf.get_booster().best_iteration + + assert clf.get_booster().attributes()["foo"] == "bar"