Test loading models with invalid file extensions. (#9955)
This commit is contained in:
@@ -254,6 +254,68 @@ class TestBoosterIO:
|
||||
# remove file
|
||||
Path.unlink(save_path)
|
||||
|
||||
def test_invalid_postfix(self) -> None:
|
||||
"""Test mis-specified model format, no special hanlding is expected, the
|
||||
JSON/UBJ parser can emit parsing errors.
|
||||
|
||||
"""
|
||||
X, y, w = tm.make_regression(64, 16, False)
|
||||
booster = xgb.train({}, xgb.QuantileDMatrix(X, y, weight=w), num_boost_round=3)
|
||||
|
||||
def rename(src: str, dst: str) -> None:
|
||||
if os.path.exists(dst):
|
||||
# Windows cannot overwrite an existing file.
|
||||
os.remove(dst)
|
||||
os.rename(src, dst)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path_dep = os.path.join(tmpdir, "model.deprecated")
|
||||
# save into deprecated format
|
||||
with pytest.warns(UserWarning, match="UBJSON"):
|
||||
booster.save_model(path_dep)
|
||||
|
||||
path_ubj = os.path.join(tmpdir, "model.ubj")
|
||||
rename(path_dep, path_ubj)
|
||||
|
||||
with pytest.raises(ValueError, match="{"):
|
||||
xgb.Booster(model_file=path_ubj)
|
||||
|
||||
path_json = os.path.join(tmpdir, "model.json")
|
||||
rename(path_ubj, path_json)
|
||||
|
||||
with pytest.raises(ValueError, match="{"):
|
||||
xgb.Booster(model_file=path_json)
|
||||
|
||||
# save into ubj format
|
||||
booster.save_model(path_ubj)
|
||||
rename(path_ubj, path_dep)
|
||||
# deprecated is not a recognized format internally, XGBoost can guess the
|
||||
# right format
|
||||
xgb.Booster(model_file=path_dep)
|
||||
rename(path_dep, path_json)
|
||||
with pytest.raises(ValueError, match="Expecting"):
|
||||
xgb.Booster(model_file=path_json)
|
||||
|
||||
# save into JSON format
|
||||
booster.save_model(path_json)
|
||||
rename(path_json, path_dep)
|
||||
# deprecated is not a recognized format internally, XGBoost can guess the
|
||||
# right format
|
||||
xgb.Booster(model_file=path_dep)
|
||||
rename(path_dep, path_ubj)
|
||||
with pytest.raises(ValueError, match="Expecting"):
|
||||
xgb.Booster(model_file=path_ubj)
|
||||
|
||||
# save model without file extension
|
||||
path_no = os.path.join(tmpdir, "model")
|
||||
with pytest.warns(UserWarning, match="UBJSON"):
|
||||
booster.save_model(path_no)
|
||||
|
||||
booster_1 = xgb.Booster(model_file=path_no)
|
||||
r0 = booster.save_raw(raw_format="json")
|
||||
r1 = booster_1.save_raw(raw_format="json")
|
||||
assert r0 == r1
|
||||
|
||||
|
||||
def save_load_model(model_path: str) -> None:
|
||||
from sklearn.datasets import load_digits
|
||||
|
||||
Reference in New Issue
Block a user