Add JSON schema to model dump. (#5660)
This commit is contained in:
parent
2c1a439869
commit
535479e69f
55
doc/dump.schema
Normal file
55
doc/dump.schema
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
{
|
||||||
|
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||||
|
"definitions": {
|
||||||
|
"split_node": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"nodeid": {
|
||||||
|
"type": "number",
|
||||||
|
"minimum": 0
|
||||||
|
},
|
||||||
|
"depth": {
|
||||||
|
"type": "number",
|
||||||
|
"minimum": 0
|
||||||
|
},
|
||||||
|
"yes": {
|
||||||
|
"type": "number",
|
||||||
|
"minimum": 0
|
||||||
|
},
|
||||||
|
"no": {
|
||||||
|
"type": "number",
|
||||||
|
"minimum": 0
|
||||||
|
},
|
||||||
|
"split": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"children": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"oneOf": [
|
||||||
|
{"$ref": "#/definitions/split_node"},
|
||||||
|
{"$ref": "#/definitions/leaf_node"}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"maxItems": 2
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["nodeid", "depth", "yes", "no", "split", "children"]
|
||||||
|
},
|
||||||
|
"leaf_node": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"nodeid": {
|
||||||
|
"type": "number",
|
||||||
|
"minimum": 0
|
||||||
|
},
|
||||||
|
"leaf": {
|
||||||
|
"type": "number"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["nodeid", "leaf"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": "object",
|
||||||
|
"$ref": "#/definitions/split_node"
|
||||||
|
}
|
||||||
@ -68,20 +68,20 @@ class TreeGenerator {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) {
|
virtual std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) const {
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
virtual std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) {
|
virtual std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) const {
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
virtual std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) {
|
virtual std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) const {
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
virtual std::string NodeStat(RegTree const& tree, int32_t nid) {
|
virtual std::string NodeStat(RegTree const& tree, int32_t nid) const {
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) = 0;
|
virtual std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const = 0;
|
||||||
|
|
||||||
virtual std::string SplitNode(RegTree const& tree, int32_t nid, uint32_t depth) {
|
virtual std::string SplitNode(RegTree const& tree, int32_t nid, uint32_t depth) {
|
||||||
auto const split_index = tree[nid].SplitIndex();
|
auto const split_index = tree[nid].SplitIndex();
|
||||||
@ -110,7 +110,7 @@ class TreeGenerator {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) = 0;
|
virtual std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) const = 0;
|
||||||
virtual std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) = 0;
|
virtual std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) = 0;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
@ -181,7 +181,7 @@ class TextGenerator : public TreeGenerator {
|
|||||||
TextGenerator(FeatureMap const& fmap, std::string const& attrs, bool with_stats) :
|
TextGenerator(FeatureMap const& fmap, std::string const& attrs, bool with_stats) :
|
||||||
TreeGenerator(fmap, with_stats) {}
|
TreeGenerator(fmap, with_stats) {}
|
||||||
|
|
||||||
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||||
static std::string kLeafTemplate = "{tabs}{nid}:leaf={leaf}{stats}";
|
static std::string kLeafTemplate = "{tabs}{nid}:leaf={leaf}{stats}";
|
||||||
static std::string kStatTemplate = ",cover={cover}";
|
static std::string kStatTemplate = ",cover={cover}";
|
||||||
std::string result = SuperT::Match(
|
std::string result = SuperT::Match(
|
||||||
@ -195,7 +195,7 @@ class TextGenerator : public TreeGenerator {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||||
static std::string const kIndicatorTemplate = "{nid}:[{fname}] yes={yes},no={no}";
|
static std::string const kIndicatorTemplate = "{nid}:[{fname}] yes={yes},no={no}";
|
||||||
int32_t nyes = tree[nid].DefaultLeft() ?
|
int32_t nyes = tree[nid].DefaultLeft() ?
|
||||||
tree[nid].RightChild() : tree[nid].LeftChild();
|
tree[nid].RightChild() : tree[nid].LeftChild();
|
||||||
@ -211,7 +211,7 @@ class TextGenerator : public TreeGenerator {
|
|||||||
|
|
||||||
std::string SplitNodeImpl(
|
std::string SplitNodeImpl(
|
||||||
RegTree const& tree, int32_t nid, std::string const& template_str,
|
RegTree const& tree, int32_t nid, std::string const& template_str,
|
||||||
std::string cond, uint32_t depth) {
|
std::string cond, uint32_t depth) const {
|
||||||
auto split_index = tree[nid].SplitIndex();
|
auto split_index = tree[nid].SplitIndex();
|
||||||
std::string const result = SuperT::Match(
|
std::string const result = SuperT::Match(
|
||||||
template_str,
|
template_str,
|
||||||
@ -226,7 +226,7 @@ class TextGenerator : public TreeGenerator {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||||
static std::string const kIntegerTemplate =
|
static std::string const kIntegerTemplate =
|
||||||
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
|
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
|
||||||
auto cond = tree[nid].SplitCond();
|
auto cond = tree[nid].SplitCond();
|
||||||
@ -238,21 +238,21 @@ class TextGenerator : public TreeGenerator {
|
|||||||
std::to_string(integer_threshold), depth);
|
std::to_string(integer_threshold), depth);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||||
static std::string const kQuantitiveTemplate =
|
static std::string const kQuantitiveTemplate =
|
||||||
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
|
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
|
||||||
auto cond = tree[nid].SplitCond();
|
auto cond = tree[nid].SplitCond();
|
||||||
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth);
|
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||||
auto cond = tree[nid].SplitCond();
|
auto cond = tree[nid].SplitCond();
|
||||||
static std::string const kNodeTemplate =
|
static std::string const kNodeTemplate =
|
||||||
"{tabs}{nid}:[f{fname}<{cond}] yes={left},no={right},missing={missing}";
|
"{tabs}{nid}:[f{fname}<{cond}] yes={left},no={right},missing={missing}";
|
||||||
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
|
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string NodeStat(RegTree const& tree, int32_t nid) override {
|
std::string NodeStat(RegTree const& tree, int32_t nid) const override {
|
||||||
static std::string const kStatTemplate = ",gain={loss_chg},cover={sum_hess}";
|
static std::string const kStatTemplate = ",gain={loss_chg},cover={sum_hess}";
|
||||||
std::string const result = SuperT::Match(
|
std::string const result = SuperT::Match(
|
||||||
kStatTemplate,
|
kStatTemplate,
|
||||||
@ -297,7 +297,7 @@ class JsonGenerator : public TreeGenerator {
|
|||||||
JsonGenerator(FeatureMap const& fmap, std::string attrs, bool with_stats) :
|
JsonGenerator(FeatureMap const& fmap, std::string attrs, bool with_stats) :
|
||||||
TreeGenerator(fmap, with_stats) {}
|
TreeGenerator(fmap, with_stats) {}
|
||||||
|
|
||||||
std::string Indent(uint32_t depth) {
|
std::string Indent(uint32_t depth) const {
|
||||||
std::string result;
|
std::string result;
|
||||||
for (uint32_t i = 0; i < depth + 1; ++i) {
|
for (uint32_t i = 0; i < depth + 1; ++i) {
|
||||||
result += " ";
|
result += " ";
|
||||||
@ -305,7 +305,7 @@ class JsonGenerator : public TreeGenerator {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||||
static std::string const kLeafTemplate =
|
static std::string const kLeafTemplate =
|
||||||
R"L({ "nodeid": {nid}, "leaf": {leaf} {stat}})L";
|
R"L({ "nodeid": {nid}, "leaf": {leaf} {stat}})L";
|
||||||
static std::string const kStatTemplate =
|
static std::string const kStatTemplate =
|
||||||
@ -321,11 +321,11 @@ class JsonGenerator : public TreeGenerator {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||||
int32_t nyes = tree[nid].DefaultLeft() ?
|
int32_t nyes = tree[nid].DefaultLeft() ?
|
||||||
tree[nid].RightChild() : tree[nid].LeftChild();
|
tree[nid].RightChild() : tree[nid].LeftChild();
|
||||||
static std::string const kIndicatorTemplate =
|
static std::string const kIndicatorTemplate =
|
||||||
R"ID( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", "yes": {yes}, "no": {no}})ID";
|
R"ID( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", "yes": {yes}, "no": {no})ID";
|
||||||
auto split_index = tree[nid].SplitIndex();
|
auto split_index = tree[nid].SplitIndex();
|
||||||
auto result = SuperT::Match(
|
auto result = SuperT::Match(
|
||||||
kIndicatorTemplate,
|
kIndicatorTemplate,
|
||||||
@ -337,8 +337,9 @@ class JsonGenerator : public TreeGenerator {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string SplitNodeImpl(RegTree const& tree, int32_t nid,
|
std::string SplitNodeImpl(RegTree const &tree, int32_t nid,
|
||||||
std::string const& template_str, std::string cond, uint32_t depth) {
|
std::string const &template_str, std::string cond,
|
||||||
|
uint32_t depth) const {
|
||||||
auto split_index = tree[nid].SplitIndex();
|
auto split_index = tree[nid].SplitIndex();
|
||||||
std::string const result = SuperT::Match(
|
std::string const result = SuperT::Match(
|
||||||
template_str,
|
template_str,
|
||||||
@ -353,7 +354,7 @@ class JsonGenerator : public TreeGenerator {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||||
auto cond = tree[nid].SplitCond();
|
auto cond = tree[nid].SplitCond();
|
||||||
const bst_float floored = std::floor(cond);
|
const bst_float floored = std::floor(cond);
|
||||||
const int32_t integer_threshold
|
const int32_t integer_threshold
|
||||||
@ -367,7 +368,7 @@ class JsonGenerator : public TreeGenerator {
|
|||||||
std::to_string(integer_threshold), depth);
|
std::to_string(integer_threshold), depth);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||||
static std::string const kQuantitiveTemplate =
|
static std::string const kQuantitiveTemplate =
|
||||||
R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I"
|
R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I"
|
||||||
R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I"
|
R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I"
|
||||||
@ -376,7 +377,7 @@ class JsonGenerator : public TreeGenerator {
|
|||||||
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth);
|
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||||
auto cond = tree[nid].SplitCond();
|
auto cond = tree[nid].SplitCond();
|
||||||
static std::string const kNodeTemplate =
|
static std::string const kNodeTemplate =
|
||||||
R"I( "nodeid": {nid}, "depth": {depth}, "split": {fname}, )I"
|
R"I( "nodeid": {nid}, "depth": {depth}, "split": {fname}, )I"
|
||||||
@ -385,7 +386,7 @@ class JsonGenerator : public TreeGenerator {
|
|||||||
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
|
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string NodeStat(RegTree const& tree, int32_t nid) override {
|
std::string NodeStat(RegTree const& tree, int32_t nid) const override {
|
||||||
static std::string kStatTemplate =
|
static std::string kStatTemplate =
|
||||||
R"S(, "gain": {loss_chg}, "cover": {sum_hess})S";
|
R"S(, "gain": {loss_chg}, "cover": {sum_hess})S";
|
||||||
auto result = SuperT::Match(
|
auto result = SuperT::Match(
|
||||||
@ -529,7 +530,7 @@ class GraphvizGenerator : public TreeGenerator {
|
|||||||
protected:
|
protected:
|
||||||
// Only indicator is different, so we combine all different node types into this
|
// Only indicator is different, so we combine all different node types into this
|
||||||
// function.
|
// function.
|
||||||
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||||
auto split = tree[nid].SplitIndex();
|
auto split = tree[nid].SplitIndex();
|
||||||
auto cond = tree[nid].SplitCond();
|
auto cond = tree[nid].SplitCond();
|
||||||
static std::string const kNodeTemplate =
|
static std::string const kNodeTemplate =
|
||||||
@ -563,7 +564,7 @@ class GraphvizGenerator : public TreeGenerator {
|
|||||||
return result;
|
return result;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||||
static std::string const kLeafTemplate =
|
static std::string const kLeafTemplate =
|
||||||
" {nid} [ label=\"leaf={leaf-value}\" {params}]\n";
|
" {nid} [ label=\"leaf={leaf-value}\" {params}]\n";
|
||||||
auto result = SuperT::Match(kLeafTemplate, {
|
auto result = SuperT::Match(kLeafTemplate, {
|
||||||
|
|||||||
@ -151,6 +151,10 @@ TEST(Tree, DumpJson) {
|
|||||||
|
|
||||||
str = tree.DumpModel(fmap, false, "json");
|
str = tree.DumpModel(fmap, false, "json");
|
||||||
ASSERT_EQ(str.find("cover"), std::string::npos);
|
ASSERT_EQ(str.find("cover"), std::string::npos);
|
||||||
|
|
||||||
|
|
||||||
|
auto j_tree = Json::Load({str.c_str(), str.size()});
|
||||||
|
ASSERT_EQ(get<Array>(j_tree["children"]).size(), 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(Tree, DumpText) {
|
TEST(Tree, DumpText) {
|
||||||
|
|||||||
@ -325,7 +325,7 @@ class TestModels(unittest.TestCase):
|
|||||||
assert locale.getpreferredencoding(False) == loc
|
assert locale.getpreferredencoding(False) == loc
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_json_schema())
|
@pytest.mark.skipif(**tm.no_json_schema())
|
||||||
def test_json_schema(self):
|
def test_json_io_schema(self):
|
||||||
import jsonschema
|
import jsonschema
|
||||||
model_path = 'test_json_schema.json'
|
model_path = 'test_json_schema.json'
|
||||||
path = os.path.dirname(
|
path = os.path.dirname(
|
||||||
@ -342,3 +342,35 @@ class TestModels(unittest.TestCase):
|
|||||||
jsonschema.validate(instance=json_model(model_path, parameters),
|
jsonschema.validate(instance=json_model(model_path, parameters),
|
||||||
schema=schema)
|
schema=schema)
|
||||||
os.remove(model_path)
|
os.remove(model_path)
|
||||||
|
|
||||||
|
@pytest.mark.skipif(**tm.no_json_schema())
def test_json_dump_schema(self):
    """Check that each entry of a JSON model dump conforms to ``doc/dump.schema``.

    A 4-class model is trained on random data once with the ``gbtree``
    booster and once with ``dart``; every string returned by
    ``get_dump(dump_format='json')`` is parsed and validated against the
    schema file.
    """
    import jsonschema

    def validate_model(parameters):
        # Train on random data and validate every dumped entry.
        # Note: ``parameters`` is updated in place with ``num_class``.
        features = np.random.random((100, 30))
        labels = np.random.randint(0, 4, size=(100,))

        parameters['num_class'] = 4
        dtrain = xgb.DMatrix(features, labels)

        booster = xgb.train(parameters, dtrain)
        dumped = booster.get_dump(dump_format='json')

        for entry in dumped:
            jsonschema.validate(instance=json.loads(entry), schema=schema)

    # Three dirname() hops up from this file's absolute path, then doc/.
    this_file = os.path.abspath(__file__)
    base_dir = os.path.dirname(os.path.dirname(os.path.dirname(this_file)))
    schema_path = os.path.join(base_dir, 'doc', 'dump.schema')
    with open(schema_path, 'r') as fd:
        schema = json.load(fd)

    for booster_name in ('gbtree', 'dart'):
        validate_model({'tree_method': 'hist', 'booster': booster_name,
                        'objective': 'multi:softmax'})
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user