Test loading models with invalid file extensions. (#9955)

This commit is contained in:
Jiaming Yuan
2024-01-08 19:26:24 +08:00
committed by GitHub
parent 3ff3a5f1ed
commit 9a30bdd313
2 changed files with 82 additions and 10 deletions

View File

@@ -254,6 +254,68 @@ class TestBoosterIO:
# remove file
Path.unlink(save_path)
def test_invalid_postfix(self) -> None:
    """Test loading models whose file extension does not match their content.

    No special handling is expected from XGBoost here: when the extension
    names the wrong format, the JSON/UBJ parser is allowed to emit a parsing
    error (surfaced as ``ValueError``).  When the extension is not a
    recognized one, loading is expected to succeed.
    """
    X, y, w = tm.make_regression(64, 16, False)
    booster = xgb.train({}, xgb.QuantileDMatrix(X, y, weight=w), num_boost_round=3)

    def rename(src: str, dst: str) -> None:
        """Move ``src`` to ``dst``, overwriting ``dst`` if it already exists."""
        # os.rename cannot overwrite an existing file on Windows; os.replace
        # has the same overwrite semantics on every platform.
        os.replace(src, dst)

    with tempfile.TemporaryDirectory() as tmpdir:
        path_dep = os.path.join(tmpdir, "model.deprecated")
        path_ubj = os.path.join(tmpdir, "model.ubj")
        path_json = os.path.join(tmpdir, "model.json")

        # save into deprecated format
        with pytest.warns(UserWarning, match="UBJSON"):
            booster.save_model(path_dep)

        # deprecated content behind a .ubj extension: parser error expected
        rename(path_dep, path_ubj)
        with pytest.raises(ValueError, match="{"):
            xgb.Booster(model_file=path_ubj)

        # deprecated content behind a .json extension: parser error expected
        rename(path_ubj, path_json)
        with pytest.raises(ValueError, match="{"):
            xgb.Booster(model_file=path_json)

        # save into ubj format
        booster.save_model(path_ubj)

        rename(path_ubj, path_dep)
        # deprecated is not a recognized format internally, XGBoost can guess the
        # right format
        xgb.Booster(model_file=path_dep)

        # UBJ content behind a .json extension: parser error expected
        rename(path_dep, path_json)
        with pytest.raises(ValueError, match="Expecting"):
            xgb.Booster(model_file=path_json)

        # save into JSON format
        booster.save_model(path_json)

        rename(path_json, path_dep)
        # deprecated is not a recognized format internally, XGBoost can guess the
        # right format
        xgb.Booster(model_file=path_dep)

        # JSON content behind a .ubj extension: parser error expected
        rename(path_dep, path_ubj)
        with pytest.raises(ValueError, match="Expecting"):
            xgb.Booster(model_file=path_ubj)

        # save model without file extension
        path_no = os.path.join(tmpdir, "model")
        with pytest.warns(UserWarning, match="UBJSON"):
            booster.save_model(path_no)

        # The model must survive the round trip through the extension-less file.
        booster_1 = xgb.Booster(model_file=path_no)
        r0 = booster.save_raw(raw_format="json")
        r1 = booster_1.save_raw(raw_format="json")
        assert r0 == r1
def save_load_model(model_path: str) -> None:
from sklearn.datasets import load_digits