Handle special characters in JSON model dump. (#9474)

This commit is contained in:
Jiaming Yuan
2023-08-14 15:49:00 +08:00
committed by GitHub
parent f03463c45b
commit 05d7000096
7 changed files with 127 additions and 103 deletions

View File

@@ -439,6 +439,26 @@ class TestModels:
'objective': 'multi:softmax'}
validate_model(parameters)
def test_special_model_dump_characters(self):
    """Check that a JSON model dump stays parseable with unusual feature names.

    A small regression model is trained on features whose names contain
    double quotes, a tab and a newline.  Every dumped tree must load with
    ``json.loads`` and every ``"split"`` entry found anywhere in the
    loaded structure must be one of the original feature names.
    """
    params = {"objective": "reg:squarederror", "max_depth": 3}
    # Names deliberately include characters that need escaping in JSON.
    feature_names = ['"feature 0"', "\tfeature\n1", "feature 2"]
    # The third returned value is not needed by this test.
    X, y, _ = tm.make_regression(n_samples=128, n_features=3, use_cupy=False)
    Xy = xgb.DMatrix(X, label=y, feature_names=feature_names)
    booster = xgb.train(params, Xy, num_boost_round=3)

    json_dump = booster.get_dump(dump_format="json")
    # Three boosting rounds were requested; three dump entries are expected.
    assert len(json_dump) == 3

    def validate(obj) -> None:
        """Recursively assert that every "split" value is a known feature name."""
        if isinstance(obj, list):
            # Nested nodes may be stored in arrays (presumably under a
            # "children" key -- verify against the dump schema); without
            # this branch only the top-level node would ever be checked.
            for item in obj:
                validate(item)
            return
        if not isinstance(obj, dict):
            # Scalars (numbers, strings, booleans) carry nothing to check.
            return
        for k, v in obj.items():
            if k == "split":
                assert v in feature_names
            else:
                validate(v)

    for j_tree in json_dump:
        # json.loads itself fails if special characters were not escaped.
        loaded = json.loads(j_tree)
        validate(loaded)
def test_categorical_model_io(self):
X, y = tm.make_categorical(256, 16, 71, False)
Xy = xgb.DMatrix(X, y, enable_categorical=True)