Fix feature names with special characters. (#9923)

This commit is contained in:
Jiaming Yuan
2023-12-28 22:45:13 +08:00
committed by GitHub
parent a197899161
commit a7226c0222
4 changed files with 88 additions and 68 deletions

View File

@@ -28,10 +28,11 @@ def json_model(model_path: str, parameters: dict) -> dict:
if model_path.endswith("ubj"):
import ubjson
with open(model_path, "rb") as ubjfd:
model = ubjson.load(ubjfd)
else:
with open(model_path, 'r') as fd:
with open(model_path, "r") as fd:
model = json.load(fd)
return model
@@ -439,25 +440,34 @@ class TestModels:
'objective': 'multi:softmax'}
validate_model(parameters)
def test_special_model_dump_characters(self):
def test_special_model_dump_characters(self) -> None:
params = {"objective": "reg:squarederror", "max_depth": 3}
feature_names = ['"feature 0"', "\tfeature\n1", "feature 2"]
feature_names = ['"feature 0"', "\tfeature\n1", """feature "2"."""]
X, y, w = tm.make_regression(n_samples=128, n_features=3, use_cupy=False)
Xy = xgb.DMatrix(X, label=y, feature_names=feature_names)
booster = xgb.train(params, Xy, num_boost_round=3)
json_dump = booster.get_dump(dump_format="json")
assert len(json_dump) == 3
def validate(obj: dict) -> None:
def validate_json(obj: dict) -> None:
for k, v in obj.items():
if k == "split":
assert v in feature_names
elif isinstance(v, dict):
validate(v)
validate_json(v)
for j_tree in json_dump:
loaded = json.loads(j_tree)
validate(loaded)
validate_json(loaded)
dot_dump = booster.get_dump(dump_format="dot")
for d in dot_dump:
assert d.find(r"feature \"2\"") != -1
text_dump = booster.get_dump(dump_format="text")
for d in text_dump:
assert d.find(r"feature \"2\"") != -1
def test_categorical_model_io(self):
X, y = tm.make_categorical(256, 16, 71, False)
@@ -485,6 +495,7 @@ class TestModels:
@pytest.mark.skipif(**tm.no_sklearn())
def test_attributes(self):
from sklearn.datasets import load_iris
X, y = load_iris(return_X_y=True)
cls = xgb.XGBClassifier(n_estimators=2)
cls.fit(X, y, early_stopping_rounds=1, eval_set=[(X, y)])
@@ -674,6 +685,7 @@ class TestModels:
@pytest.mark.skipif(**tm.no_pandas())
def test_feature_info(self):
import pandas as pd
rows = 100
cols = 10
X = rng.randn(rows, cols)