Add JSON schema to model dump. (#5660)

This commit is contained in:
Jiaming Yuan
2020-05-15 10:18:43 +08:00
committed by GitHub
parent 2c1a439869
commit 535479e69f
4 changed files with 118 additions and 26 deletions

View File

@@ -151,6 +151,10 @@ TEST(Tree, DumpJson) {
str = tree.DumpModel(fmap, false, "json");
ASSERT_EQ(str.find("cover"), std::string::npos);
auto j_tree = Json::Load({str.c_str(), str.size()});
ASSERT_EQ(get<Array>(j_tree["children"]).size(), 2);
}
TEST(Tree, DumpText) {

View File

@@ -325,7 +325,7 @@ class TestModels(unittest.TestCase):
assert locale.getpreferredencoding(False) == loc
@pytest.mark.skipif(**tm.no_json_schema())
def test_json_schema(self):
def test_json_io_schema(self):
import jsonschema
model_path = 'test_json_schema.json'
path = os.path.dirname(
@@ -342,3 +342,35 @@ class TestModels(unittest.TestCase):
jsonschema.validate(instance=json_model(model_path, parameters),
schema=schema)
os.remove(model_path)
@pytest.mark.skipif(**tm.no_json_schema())
def test_json_dump_schema(self):
    """Check that JSON model dumps conform to ``doc/dump.schema``.

    A small 4-class model is trained once with the ``gbtree`` booster and
    once with the ``dart`` booster.  Every entry returned by
    ``Booster.get_dump(dump_format='json')`` is parsed and validated
    against the schema; ``jsonschema.validate`` raises on a mismatch,
    which fails the test.
    """
    import jsonschema

    # The schema is looked up in the ``doc`` directory found by stripping
    # three path components from this file's absolute path.
    path = os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    doc = os.path.join(path, 'doc', 'dump.schema')
    with open(doc, 'r') as fd:
        schema = json.load(fd)

    def validate_model(parameters):
        """Train with ``parameters`` and validate each dumped entry."""
        X = np.random.random((100, 30))
        # Labels are drawn from {0, 1, 2, 3}, matching ``num_class`` below.
        y = np.random.randint(0, 4, size=(100,))
        # Build a new dict so the caller's ``parameters`` is not mutated.
        params = dict(parameters, num_class=4)
        m = xgb.DMatrix(X, y)
        booster = xgb.train(params, m)
        for tree in booster.get_dump(dump_format='json'):
            jsonschema.validate(instance=json.loads(tree), schema=schema)

    for booster_type in ('gbtree', 'dart'):
        validate_model({'tree_method': 'hist', 'booster': booster_type,
                        'objective': 'multi:softmax'})