Add JSON schema to model dump. (#5660)

This commit is contained in:
Jiaming Yuan 2020-05-15 10:18:43 +08:00 committed by GitHub
parent 2c1a439869
commit 535479e69f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 118 additions and 26 deletions

55
doc/dump.schema Normal file
View File

@ -0,0 +1,55 @@
{
"$schema": "http://json-schema.org/draft-07/schema#",
"definitions": {
"split_node": {
"type": "object",
"properties": {
"nodeid": {
"type": "number",
"minimum": 0
},
"depth": {
"type": "number",
"minimum": 0
},
"yes": {
"type": "number",
"minimum": 0
},
"no": {
"type": "number",
"minimum": 0
},
"split": {
"type": "string"
},
"children": {
"type": "array",
"items": {
"oneOf": [
{"$ref": "#/definitions/split_node"},
{"$ref": "#/definitions/leaf_node"}
]
},
"maxItems": 2
}
},
"required": ["nodeid", "depth", "yes", "no", "split", "children"]
},
"leaf_node": {
"type": "object",
"properties": {
"nodeid": {
"type": "number",
"minimum": 0
},
"leaf": {
"type": "number"
}
},
"required": ["nodeid", "leaf"]
}
},
"type": "object",
"$ref": "#/definitions/split_node"
}

View File

@ -68,20 +68,20 @@ class TreeGenerator {
return result; return result;
} }
virtual std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) { virtual std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) const {
return ""; return "";
} }
virtual std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) { virtual std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) const {
return ""; return "";
} }
virtual std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) { virtual std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) const {
return ""; return "";
} }
virtual std::string NodeStat(RegTree const& tree, int32_t nid) { virtual std::string NodeStat(RegTree const& tree, int32_t nid) const {
return ""; return "";
} }
virtual std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) = 0; virtual std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const = 0;
virtual std::string SplitNode(RegTree const& tree, int32_t nid, uint32_t depth) { virtual std::string SplitNode(RegTree const& tree, int32_t nid, uint32_t depth) {
auto const split_index = tree[nid].SplitIndex(); auto const split_index = tree[nid].SplitIndex();
@ -110,7 +110,7 @@ class TreeGenerator {
return result; return result;
} }
virtual std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) = 0; virtual std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) const = 0;
virtual std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) = 0; virtual std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) = 0;
public: public:
@ -181,7 +181,7 @@ class TextGenerator : public TreeGenerator {
TextGenerator(FeatureMap const& fmap, std::string const& attrs, bool with_stats) : TextGenerator(FeatureMap const& fmap, std::string const& attrs, bool with_stats) :
TreeGenerator(fmap, with_stats) {} TreeGenerator(fmap, with_stats) {}
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) override { std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
static std::string kLeafTemplate = "{tabs}{nid}:leaf={leaf}{stats}"; static std::string kLeafTemplate = "{tabs}{nid}:leaf={leaf}{stats}";
static std::string kStatTemplate = ",cover={cover}"; static std::string kStatTemplate = ",cover={cover}";
std::string result = SuperT::Match( std::string result = SuperT::Match(
@ -195,7 +195,7 @@ class TextGenerator : public TreeGenerator {
return result; return result;
} }
std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) override { std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) const override {
static std::string const kIndicatorTemplate = "{nid}:[{fname}] yes={yes},no={no}"; static std::string const kIndicatorTemplate = "{nid}:[{fname}] yes={yes},no={no}";
int32_t nyes = tree[nid].DefaultLeft() ? int32_t nyes = tree[nid].DefaultLeft() ?
tree[nid].RightChild() : tree[nid].LeftChild(); tree[nid].RightChild() : tree[nid].LeftChild();
@ -211,7 +211,7 @@ class TextGenerator : public TreeGenerator {
std::string SplitNodeImpl( std::string SplitNodeImpl(
RegTree const& tree, int32_t nid, std::string const& template_str, RegTree const& tree, int32_t nid, std::string const& template_str,
std::string cond, uint32_t depth) { std::string cond, uint32_t depth) const {
auto split_index = tree[nid].SplitIndex(); auto split_index = tree[nid].SplitIndex();
std::string const result = SuperT::Match( std::string const result = SuperT::Match(
template_str, template_str,
@ -226,7 +226,7 @@ class TextGenerator : public TreeGenerator {
return result; return result;
} }
std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) override { std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) const override {
static std::string const kIntegerTemplate = static std::string const kIntegerTemplate =
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}"; "{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
auto cond = tree[nid].SplitCond(); auto cond = tree[nid].SplitCond();
@ -238,21 +238,21 @@ class TextGenerator : public TreeGenerator {
std::to_string(integer_threshold), depth); std::to_string(integer_threshold), depth);
} }
std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) override { std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) const override {
static std::string const kQuantitiveTemplate = static std::string const kQuantitiveTemplate =
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}"; "{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
auto cond = tree[nid].SplitCond(); auto cond = tree[nid].SplitCond();
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth); return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth);
} }
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) override { std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
auto cond = tree[nid].SplitCond(); auto cond = tree[nid].SplitCond();
static std::string const kNodeTemplate = static std::string const kNodeTemplate =
"{tabs}{nid}:[f{fname}<{cond}] yes={left},no={right},missing={missing}"; "{tabs}{nid}:[f{fname}<{cond}] yes={left},no={right},missing={missing}";
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth); return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
} }
std::string NodeStat(RegTree const& tree, int32_t nid) override { std::string NodeStat(RegTree const& tree, int32_t nid) const override {
static std::string const kStatTemplate = ",gain={loss_chg},cover={sum_hess}"; static std::string const kStatTemplate = ",gain={loss_chg},cover={sum_hess}";
std::string const result = SuperT::Match( std::string const result = SuperT::Match(
kStatTemplate, kStatTemplate,
@ -297,7 +297,7 @@ class JsonGenerator : public TreeGenerator {
JsonGenerator(FeatureMap const& fmap, std::string attrs, bool with_stats) : JsonGenerator(FeatureMap const& fmap, std::string attrs, bool with_stats) :
TreeGenerator(fmap, with_stats) {} TreeGenerator(fmap, with_stats) {}
std::string Indent(uint32_t depth) { std::string Indent(uint32_t depth) const {
std::string result; std::string result;
for (uint32_t i = 0; i < depth + 1; ++i) { for (uint32_t i = 0; i < depth + 1; ++i) {
result += " "; result += " ";
@ -305,7 +305,7 @@ class JsonGenerator : public TreeGenerator {
return result; return result;
} }
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) override { std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
static std::string const kLeafTemplate = static std::string const kLeafTemplate =
R"L({ "nodeid": {nid}, "leaf": {leaf} {stat}})L"; R"L({ "nodeid": {nid}, "leaf": {leaf} {stat}})L";
static std::string const kStatTemplate = static std::string const kStatTemplate =
@ -321,11 +321,11 @@ class JsonGenerator : public TreeGenerator {
return result; return result;
} }
std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) override { std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) const override {
int32_t nyes = tree[nid].DefaultLeft() ? int32_t nyes = tree[nid].DefaultLeft() ?
tree[nid].RightChild() : tree[nid].LeftChild(); tree[nid].RightChild() : tree[nid].LeftChild();
static std::string const kIndicatorTemplate = static std::string const kIndicatorTemplate =
R"ID( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", "yes": {yes}, "no": {no}})ID"; R"ID( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", "yes": {yes}, "no": {no})ID";
auto split_index = tree[nid].SplitIndex(); auto split_index = tree[nid].SplitIndex();
auto result = SuperT::Match( auto result = SuperT::Match(
kIndicatorTemplate, kIndicatorTemplate,
@ -337,8 +337,9 @@ class JsonGenerator : public TreeGenerator {
return result; return result;
} }
std::string SplitNodeImpl(RegTree const& tree, int32_t nid, std::string SplitNodeImpl(RegTree const &tree, int32_t nid,
std::string const& template_str, std::string cond, uint32_t depth) { std::string const &template_str, std::string cond,
uint32_t depth) const {
auto split_index = tree[nid].SplitIndex(); auto split_index = tree[nid].SplitIndex();
std::string const result = SuperT::Match( std::string const result = SuperT::Match(
template_str, template_str,
@ -353,7 +354,7 @@ class JsonGenerator : public TreeGenerator {
return result; return result;
} }
std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) override { std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) const override {
auto cond = tree[nid].SplitCond(); auto cond = tree[nid].SplitCond();
const bst_float floored = std::floor(cond); const bst_float floored = std::floor(cond);
const int32_t integer_threshold const int32_t integer_threshold
@ -367,7 +368,7 @@ class JsonGenerator : public TreeGenerator {
std::to_string(integer_threshold), depth); std::to_string(integer_threshold), depth);
} }
std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) override { std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) const override {
static std::string const kQuantitiveTemplate = static std::string const kQuantitiveTemplate =
R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I" R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I"
R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I" R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I"
@ -376,7 +377,7 @@ class JsonGenerator : public TreeGenerator {
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth); return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth);
} }
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) override { std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
auto cond = tree[nid].SplitCond(); auto cond = tree[nid].SplitCond();
static std::string const kNodeTemplate = static std::string const kNodeTemplate =
R"I( "nodeid": {nid}, "depth": {depth}, "split": {fname}, )I" R"I( "nodeid": {nid}, "depth": {depth}, "split": {fname}, )I"
@ -385,7 +386,7 @@ class JsonGenerator : public TreeGenerator {
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth); return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
} }
std::string NodeStat(RegTree const& tree, int32_t nid) override { std::string NodeStat(RegTree const& tree, int32_t nid) const override {
static std::string kStatTemplate = static std::string kStatTemplate =
R"S(, "gain": {loss_chg}, "cover": {sum_hess})S"; R"S(, "gain": {loss_chg}, "cover": {sum_hess})S";
auto result = SuperT::Match( auto result = SuperT::Match(
@ -529,7 +530,7 @@ class GraphvizGenerator : public TreeGenerator {
protected: protected:
// Only indicator is different, so we combine all different node types into this // Only indicator is different, so we combine all different node types into this
// function. // function.
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) override { std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
auto split = tree[nid].SplitIndex(); auto split = tree[nid].SplitIndex();
auto cond = tree[nid].SplitCond(); auto cond = tree[nid].SplitCond();
static std::string const kNodeTemplate = static std::string const kNodeTemplate =
@ -563,7 +564,7 @@ class GraphvizGenerator : public TreeGenerator {
return result; return result;
}; };
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) override { std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
static std::string const kLeafTemplate = static std::string const kLeafTemplate =
" {nid} [ label=\"leaf={leaf-value}\" {params}]\n"; " {nid} [ label=\"leaf={leaf-value}\" {params}]\n";
auto result = SuperT::Match(kLeafTemplate, { auto result = SuperT::Match(kLeafTemplate, {

View File

@ -151,6 +151,10 @@ TEST(Tree, DumpJson) {
str = tree.DumpModel(fmap, false, "json"); str = tree.DumpModel(fmap, false, "json");
ASSERT_EQ(str.find("cover"), std::string::npos); ASSERT_EQ(str.find("cover"), std::string::npos);
auto j_tree = Json::Load({str.c_str(), str.size()});
ASSERT_EQ(get<Array>(j_tree["children"]).size(), 2);
} }
TEST(Tree, DumpText) { TEST(Tree, DumpText) {

View File

@ -325,7 +325,7 @@ class TestModels(unittest.TestCase):
assert locale.getpreferredencoding(False) == loc assert locale.getpreferredencoding(False) == loc
@pytest.mark.skipif(**tm.no_json_schema()) @pytest.mark.skipif(**tm.no_json_schema())
def test_json_schema(self): def test_json_io_schema(self):
import jsonschema import jsonschema
model_path = 'test_json_schema.json' model_path = 'test_json_schema.json'
path = os.path.dirname( path = os.path.dirname(
@ -342,3 +342,35 @@ class TestModels(unittest.TestCase):
jsonschema.validate(instance=json_model(model_path, parameters), jsonschema.validate(instance=json_model(model_path, parameters),
schema=schema) schema=schema)
os.remove(model_path) os.remove(model_path)
@pytest.mark.skipif(**tm.no_json_schema())
def test_json_dump_schema(self):
import jsonschema
def validate_model(parameters):
X = np.random.random((100, 30))
y = np.random.randint(0, 4, size=(100,))
parameters['num_class'] = 4
m = xgb.DMatrix(X, y)
booster = xgb.train(parameters, m)
dump = booster.get_dump(dump_format='json')
for i in range(len(dump)):
jsonschema.validate(instance=json.loads(dump[i]),
schema=schema)
path = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
doc = os.path.join(path, 'doc', 'dump.schema')
with open(doc, 'r') as fd:
schema = json.load(fd)
parameters = {'tree_method': 'hist', 'booster': 'gbtree',
'objective': 'multi:softmax'}
validate_model(parameters)
parameters = {'tree_method': 'hist', 'booster': 'dart',
'objective': 'multi:softmax'}
validate_model(parameters)