Handle special characters in JSON model dump. (#9474)
This commit is contained in:
@@ -439,6 +439,26 @@ class TestModels:
|
||||
'objective': 'multi:softmax'}
|
||||
validate_model(parameters)
|
||||
|
||||
def test_special_model_dump_characters(self):
|
||||
params = {"objective": "reg:squarederror", "max_depth": 3}
|
||||
feature_names = ['"feature 0"', "\tfeature\n1", "feature 2"]
|
||||
X, y, w = tm.make_regression(n_samples=128, n_features=3, use_cupy=False)
|
||||
Xy = xgb.DMatrix(X, label=y, feature_names=feature_names)
|
||||
booster = xgb.train(params, Xy, num_boost_round=3)
|
||||
json_dump = booster.get_dump(dump_format="json")
|
||||
assert len(json_dump) == 3
|
||||
|
||||
def validate(obj: dict) -> None:
|
||||
for k, v in obj.items():
|
||||
if k == "split":
|
||||
assert v in feature_names
|
||||
elif isinstance(v, dict):
|
||||
validate(v)
|
||||
|
||||
for j_tree in json_dump:
|
||||
loaded = json.loads(j_tree)
|
||||
validate(loaded)
|
||||
|
||||
def test_categorical_model_io(self):
|
||||
X, y = tm.make_categorical(256, 16, 71, False)
|
||||
Xy = xgb.DMatrix(X, y, enable_categorical=True)
|
||||
|
||||
Reference in New Issue
Block a user