Add JSON schema to model dump. (#5660)
This commit is contained in:
parent
c42f533ae9
commit
66690f3d07
55
doc/dump.schema
Normal file
55
doc/dump.schema
Normal file
@ -0,0 +1,55 @@
|
||||
{
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"definitions": {
|
||||
"split_node": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"nodeid": {
|
||||
"type": "number",
|
||||
"minimum": 0
|
||||
},
|
||||
"depth": {
|
||||
"type": "number",
|
||||
"minimum": 0
|
||||
},
|
||||
"yes": {
|
||||
"type": "number",
|
||||
"minimum": 0
|
||||
},
|
||||
"no": {
|
||||
"type": "number",
|
||||
"minimum": 0
|
||||
},
|
||||
"split": {
|
||||
"type": "string"
|
||||
},
|
||||
"children": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"oneOf": [
|
||||
{"$ref": "#/definitions/split_node"},
|
||||
{"$ref": "#/definitions/leaf_node"}
|
||||
]
|
||||
},
|
||||
"maxItems": 2
|
||||
}
|
||||
},
|
||||
"required": ["nodeid", "depth", "yes", "no", "split", "children"]
|
||||
},
|
||||
"leaf_node": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"nodeid": {
|
||||
"type": "number",
|
||||
"minimum": 0
|
||||
},
|
||||
"leaf": {
|
||||
"type": "number"
|
||||
}
|
||||
},
|
||||
"required": ["nodeid", "leaf"]
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"$ref": "#/definitions/split_node"
|
||||
}
|
||||
@ -68,20 +68,20 @@ class TreeGenerator {
|
||||
return result;
|
||||
}
|
||||
|
||||
// Base-class hook for indicator splits: returns an empty string.
// TextGenerator and JsonGenerator override it to emit the node id, feature
// name and yes/no child ids.
// The two signature lines below are the pre-change and post-change versions
// shown by the diff; the change adds the `const` qualifier.
virtual std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) {
|
||||
virtual std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) const {
|
||||
return "";
|
||||
}
|
||||
// Base-class hook for integer splits: returns an empty string.
// The overrides visible in this diff render the split condition as an
// integer threshold via std::to_string(integer_threshold).
// The two signature lines below are the pre-change and post-change versions
// shown by the diff; the change adds the `const` qualifier.
virtual std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) {
|
||||
virtual std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) const {
|
||||
return "";
|
||||
}
|
||||
// Base-class hook for quantitative splits: returns an empty string.
// The overrides visible in this diff render the split condition with
// SuperT::ToStr(tree[nid].SplitCond()).
// The two signature lines below are the pre-change and post-change versions
// shown by the diff; the change adds the `const` qualifier.
virtual std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) {
|
||||
virtual std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) const {
|
||||
return "";
|
||||
}
|
||||
// Base-class hook for per-node statistics: returns an empty string.
// The overrides visible in this diff fill a template with gain ({loss_chg})
// and cover ({sum_hess}) placeholders.
// The two signature lines below are the pre-change and post-change versions
// shown by the diff; the change adds the `const` qualifier.
virtual std::string NodeStat(RegTree const& tree, int32_t nid) {
|
||||
virtual std::string NodeStat(RegTree const& tree, int32_t nid) const {
|
||||
return "";
|
||||
}
|
||||
|
||||
virtual std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) = 0;
|
||||
virtual std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const = 0;
|
||||
|
||||
virtual std::string SplitNode(RegTree const& tree, int32_t nid, uint32_t depth) {
|
||||
auto const split_index = tree[nid].SplitIndex();
|
||||
@ -110,7 +110,7 @@ class TreeGenerator {
|
||||
return result;
|
||||
}
|
||||
|
||||
virtual std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) = 0;
|
||||
virtual std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) const = 0;
|
||||
virtual std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) = 0;
|
||||
|
||||
public:
|
||||
@ -181,7 +181,7 @@ class TextGenerator : public TreeGenerator {
|
||||
TextGenerator(FeatureMap const& fmap, std::string const& attrs, bool with_stats) :
|
||||
TreeGenerator(fmap, with_stats) {}
|
||||
|
||||
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
||||
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||
static std::string kLeafTemplate = "{tabs}{nid}:leaf={leaf}{stats}";
|
||||
static std::string kStatTemplate = ",cover={cover}";
|
||||
std::string result = SuperT::Match(
|
||||
@ -195,7 +195,7 @@ class TextGenerator : public TreeGenerator {
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
||||
std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||
static std::string const kIndicatorTemplate = "{nid}:[{fname}] yes={yes},no={no}";
|
||||
int32_t nyes = tree[nid].DefaultLeft() ?
|
||||
tree[nid].RightChild() : tree[nid].LeftChild();
|
||||
@ -211,7 +211,7 @@ class TextGenerator : public TreeGenerator {
|
||||
|
||||
std::string SplitNodeImpl(
|
||||
RegTree const& tree, int32_t nid, std::string const& template_str,
|
||||
std::string cond, uint32_t depth) {
|
||||
std::string cond, uint32_t depth) const {
|
||||
auto split_index = tree[nid].SplitIndex();
|
||||
std::string const result = SuperT::Match(
|
||||
template_str,
|
||||
@ -226,7 +226,7 @@ class TextGenerator : public TreeGenerator {
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
||||
std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||
static std::string const kIntegerTemplate =
|
||||
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
|
||||
auto cond = tree[nid].SplitCond();
|
||||
@ -238,21 +238,21 @@ class TextGenerator : public TreeGenerator {
|
||||
std::to_string(integer_threshold), depth);
|
||||
}
|
||||
|
||||
// Text rendering of a quantitative split node.
// Reads the split condition from tree[nid], converts it with SuperT::ToStr
// and delegates the template substitution to SplitNodeImpl using
// kQuantitiveTemplate ("{tabs}{nid}:[{fname}<{cond}] yes=...,no=...,missing=...").
// The two signature lines below are the pre-change and post-change versions
// shown by the diff; the change adds the `const` qualifier.
std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
||||
std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||
static std::string const kQuantitiveTemplate =
|
||||
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
|
||||
auto cond = tree[nid].SplitCond();
|
||||
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth);
|
||||
}
|
||||
|
||||
// Text rendering of a plain split node.
// Same flow as Quantitive, but the template prefixes the feature with a
// literal 'f' ("[f{fname}<{cond}]"); the condition is converted with
// SuperT::ToStr and substitution is delegated to SplitNodeImpl.
// The two signature lines below are the pre-change and post-change versions
// shown by the diff; the change adds the `const` qualifier.
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
||||
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||
auto cond = tree[nid].SplitCond();
|
||||
static std::string const kNodeTemplate =
|
||||
"{tabs}{nid}:[f{fname}<{cond}] yes={left},no={right},missing={missing}";
|
||||
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
|
||||
}
|
||||
|
||||
std::string NodeStat(RegTree const& tree, int32_t nid) override {
|
||||
std::string NodeStat(RegTree const& tree, int32_t nid) const override {
|
||||
static std::string const kStatTemplate = ",gain={loss_chg},cover={sum_hess}";
|
||||
std::string const result = SuperT::Match(
|
||||
kStatTemplate,
|
||||
@ -297,7 +297,7 @@ class JsonGenerator : public TreeGenerator {
|
||||
JsonGenerator(FeatureMap const& fmap, std::string attrs, bool with_stats) :
|
||||
TreeGenerator(fmap, with_stats) {}
|
||||
|
||||
std::string Indent(uint32_t depth) {
|
||||
std::string Indent(uint32_t depth) const {
|
||||
std::string result;
|
||||
for (uint32_t i = 0; i < depth + 1; ++i) {
|
||||
result += " ";
|
||||
@ -305,7 +305,7 @@ class JsonGenerator : public TreeGenerator {
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
||||
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||
static std::string const kLeafTemplate =
|
||||
R"L({ "nodeid": {nid}, "leaf": {leaf} {stat}})L";
|
||||
static std::string const kStatTemplate =
|
||||
@ -321,11 +321,11 @@ class JsonGenerator : public TreeGenerator {
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
||||
std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||
int32_t nyes = tree[nid].DefaultLeft() ?
|
||||
tree[nid].RightChild() : tree[nid].LeftChild();
|
||||
static std::string const kIndicatorTemplate =
|
||||
R"ID( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", "yes": {yes}, "no": {no}})ID";
|
||||
R"ID( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", "yes": {yes}, "no": {no})ID";
|
||||
auto split_index = tree[nid].SplitIndex();
|
||||
auto result = SuperT::Match(
|
||||
kIndicatorTemplate,
|
||||
@ -337,8 +337,9 @@ class JsonGenerator : public TreeGenerator {
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string SplitNodeImpl(RegTree const& tree, int32_t nid,
|
||||
std::string const& template_str, std::string cond, uint32_t depth) {
|
||||
std::string SplitNodeImpl(RegTree const &tree, int32_t nid,
|
||||
std::string const &template_str, std::string cond,
|
||||
uint32_t depth) const {
|
||||
auto split_index = tree[nid].SplitIndex();
|
||||
std::string const result = SuperT::Match(
|
||||
template_str,
|
||||
@ -353,7 +354,7 @@ class JsonGenerator : public TreeGenerator {
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
||||
std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||
auto cond = tree[nid].SplitCond();
|
||||
const bst_float floored = std::floor(cond);
|
||||
const int32_t integer_threshold
|
||||
@ -367,7 +368,7 @@ class JsonGenerator : public TreeGenerator {
|
||||
std::to_string(integer_threshold), depth);
|
||||
}
|
||||
|
||||
std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
||||
std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||
static std::string const kQuantitiveTemplate =
|
||||
R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I"
|
||||
R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I"
|
||||
@ -376,7 +377,7 @@ class JsonGenerator : public TreeGenerator {
|
||||
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth);
|
||||
}
|
||||
|
||||
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
||||
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||
auto cond = tree[nid].SplitCond();
|
||||
static std::string const kNodeTemplate =
|
||||
R"I( "nodeid": {nid}, "depth": {depth}, "split": {fname}, )I"
|
||||
@ -385,7 +386,7 @@ class JsonGenerator : public TreeGenerator {
|
||||
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
|
||||
}
|
||||
|
||||
std::string NodeStat(RegTree const& tree, int32_t nid) override {
|
||||
std::string NodeStat(RegTree const& tree, int32_t nid) const override {
|
||||
static std::string kStatTemplate =
|
||||
R"S(, "gain": {loss_chg}, "cover": {sum_hess})S";
|
||||
auto result = SuperT::Match(
|
||||
@ -529,7 +530,7 @@ class GraphvizGenerator : public TreeGenerator {
|
||||
protected:
|
||||
// Only indicator is different, so we combine all different node types into this
|
||||
// function.
|
||||
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
||||
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||
auto split = tree[nid].SplitIndex();
|
||||
auto cond = tree[nid].SplitCond();
|
||||
static std::string const kNodeTemplate =
|
||||
@ -563,7 +564,7 @@ class GraphvizGenerator : public TreeGenerator {
|
||||
return result;
|
||||
};
|
||||
|
||||
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
||||
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||
static std::string const kLeafTemplate =
|
||||
" {nid} [ label=\"leaf={leaf-value}\" {params}]\n";
|
||||
auto result = SuperT::Match(kLeafTemplate, {
|
||||
|
||||
@ -151,6 +151,10 @@ TEST(Tree, DumpJson) {
|
||||
|
||||
str = tree.DumpModel(fmap, false, "json");
|
||||
ASSERT_EQ(str.find("cover"), std::string::npos);
|
||||
|
||||
|
||||
auto j_tree = Json::Load({str.c_str(), str.size()});
|
||||
ASSERT_EQ(get<Array>(j_tree["children"]).size(), 2);
|
||||
}
|
||||
|
||||
TEST(Tree, DumpText) {
|
||||
|
||||
@ -325,7 +325,7 @@ class TestModels(unittest.TestCase):
|
||||
assert locale.getpreferredencoding(False) == loc
|
||||
|
||||
@pytest.mark.skipif(**tm.no_json_schema())
|
||||
def test_json_schema(self):
|
||||
def test_json_io_schema(self):
|
||||
import jsonschema
|
||||
model_path = 'test_json_schema.json'
|
||||
path = os.path.dirname(
|
||||
@ -342,3 +342,35 @@ class TestModels(unittest.TestCase):
|
||||
jsonschema.validate(instance=json_model(model_path, parameters),
|
||||
schema=schema)
|
||||
os.remove(model_path)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_json_schema())
def test_json_dump_schema(self):
    """Check that JSON tree dumps conform to ``doc/dump.schema``.

    A 4-class ``multi:softmax`` model is trained with the ``hist`` tree
    method for both the ``gbtree`` and ``dart`` boosters, and every tree
    returned by ``get_dump(dump_format='json')`` is validated against the
    schema.
    """
    import jsonschema

    # Load the schema first so the helper below never depends on a name
    # that is assigned after its definition. The schema file sits in
    # ``doc/`` three directory levels above this test file.
    root = os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    with open(os.path.join(root, 'doc', 'dump.schema'), 'r') as fd:
        schema = json.load(fd)

    def validate_model(parameters):
        """Train with ``parameters`` and validate each dumped tree."""
        n_classes = 4
        X = np.random.random((100, 30))
        y = np.random.randint(0, n_classes, size=(100,))

        # Build a new dict rather than mutating the caller's parameters.
        params = dict(parameters, num_class=n_classes)
        booster = xgb.train(params, xgb.DMatrix(X, y))

        dump = booster.get_dump(dump_format='json')
        # Guard against a vacuous pass when nothing was dumped.
        assert dump, 'expected at least one dumped tree'
        for tree in dump:
            jsonschema.validate(instance=json.loads(tree), schema=schema)

    for booster_name in ('gbtree', 'dart'):
        validate_model({'tree_method': 'hist', 'booster': booster_name,
                        'objective': 'multi:softmax'})
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user