diff --git a/doc/dump.schema b/doc/dump.schema new file mode 100644 index 000000000..cb2c61be3 --- /dev/null +++ b/doc/dump.schema @@ -0,0 +1,55 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "definitions": { + "split_node": { + "type": "object", + "properties": { + "nodeid": { + "type": "number", + "minimum": 0 + }, + "depth": { + "type": "number", + "minimum": 0 + }, + "yes": { + "type": "number", + "minimum": 0 + }, + "no": { + "type": "number", + "minimum": 0 + }, + "split": { + "type": "string" + }, + "children": { + "type": "array", + "items": { + "oneOf": [ + {"$ref": "#/definitions/split_node"}, + {"$ref": "#/definitions/leaf_node"} + ] + }, + "maxItems": 2 + } + }, + "required": ["nodeid", "depth", "yes", "no", "split", "children"] + }, + "leaf_node": { + "type": "object", + "properties": { + "nodeid": { + "type": "number", + "minimum": 0 + }, + "leaf": { + "type": "number" + } + }, + "required": ["nodeid", "leaf"] + } + }, + "type": "object", + "$ref": "#/definitions/split_node" +} diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index e8046d109..8f45621ca 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -68,20 +68,20 @@ class TreeGenerator { return result; } - virtual std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) { + virtual std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) const { return ""; } - virtual std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) { + virtual std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) const { return ""; } - virtual std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) { + virtual std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) const { return ""; } - virtual std::string NodeStat(RegTree const& tree, int32_t nid) { + virtual std::string NodeStat(RegTree const& tree, int32_t nid) const { return ""; } - virtual std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) = 0; + virtual std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const = 0; virtual std::string SplitNode(RegTree const& tree, int32_t nid, uint32_t depth) { auto const split_index = tree[nid].SplitIndex(); @@ -110,7 +110,7 @@ class TreeGenerator { return result; } - virtual std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) = 0; + virtual std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) const = 0; virtual std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) = 0; public: @@ -181,7 +181,7 @@ class TextGenerator : public TreeGenerator { TextGenerator(FeatureMap const& fmap, std::string const& attrs, bool with_stats) : TreeGenerator(fmap, with_stats) {} - std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) override { + std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) const override { static std::string kLeafTemplate = "{tabs}{nid}:leaf={leaf}{stats}"; static std::string kStatTemplate = ",cover={cover}"; std::string result = SuperT::Match( @@ -195,7 +195,7 @@ class TextGenerator : public TreeGenerator { return result; } - std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) override { + std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) const override { static std::string const kIndicatorTemplate = "{nid}:[{fname}] yes={yes},no={no}"; int32_t nyes = tree[nid].DefaultLeft() ? tree[nid].RightChild() : tree[nid].LeftChild(); @@ -211,7 +211,7 @@ class TextGenerator : public TreeGenerator { std::string SplitNodeImpl( RegTree const& tree, int32_t nid, std::string const& template_str, - std::string cond, uint32_t depth) { + std::string cond, uint32_t depth) const { auto split_index = tree[nid].SplitIndex(); std::string const result = SuperT::Match( template_str, @@ -226,7 +226,7 @@ class TextGenerator : public TreeGenerator { return result; } - std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) override { + std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) const override { static std::string const kIntegerTemplate = "{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}"; auto cond = tree[nid].SplitCond(); @@ -238,21 +238,21 @@ class TextGenerator : public TreeGenerator { std::to_string(integer_threshold), depth); } - std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) override { + std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) const override { static std::string const kQuantitiveTemplate = "{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}"; auto cond = tree[nid].SplitCond(); return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth); } - std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) override { + std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override { auto cond = tree[nid].SplitCond(); static std::string const kNodeTemplate = "{tabs}{nid}:[f{fname}<{cond}] yes={left},no={right},missing={missing}"; return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth); } - std::string NodeStat(RegTree const& tree, int32_t nid) override { + std::string NodeStat(RegTree const& tree, int32_t nid) const override { static std::string const kStatTemplate = ",gain={loss_chg},cover={sum_hess}"; std::string const result = SuperT::Match( kStatTemplate, @@ -297,7 +297,7 @@ class JsonGenerator : public TreeGenerator { JsonGenerator(FeatureMap const& fmap, std::string attrs, bool with_stats) : TreeGenerator(fmap, with_stats) {} - std::string Indent(uint32_t depth) { + std::string Indent(uint32_t depth) const { std::string result; for (uint32_t i = 0; i < depth + 1; ++i) { result += " "; @@ -305,7 +305,7 @@ class JsonGenerator : public TreeGenerator { return result; } - std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) override { + std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) const override { static std::string const kLeafTemplate = R"L({ "nodeid": {nid}, "leaf": {leaf} {stat}})L"; static std::string const kStatTemplate = @@ -321,11 +321,11 @@ class JsonGenerator : public TreeGenerator { return result; } - std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) override { + std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) const override { int32_t nyes = tree[nid].DefaultLeft() ? tree[nid].RightChild() : tree[nid].LeftChild(); static std::string const kIndicatorTemplate = - R"ID( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", "yes": {yes}, "no": {no}})ID"; + R"ID( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", "yes": {yes}, "no": {no})ID"; auto split_index = tree[nid].SplitIndex(); auto result = SuperT::Match( kIndicatorTemplate, @@ -337,8 +337,9 @@ class JsonGenerator : public TreeGenerator { return result; } - std::string SplitNodeImpl(RegTree const& tree, int32_t nid, - std::string const& template_str, std::string cond, uint32_t depth) { + std::string SplitNodeImpl(RegTree const &tree, int32_t nid, + std::string const &template_str, std::string cond, + uint32_t depth) const { auto split_index = tree[nid].SplitIndex(); std::string const result = SuperT::Match( template_str, @@ -353,7 +354,7 @@ class JsonGenerator : public TreeGenerator { return result; } - std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) override { + std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) const override { auto cond = tree[nid].SplitCond(); const bst_float floored = std::floor(cond); const int32_t integer_threshold @@ -367,7 +368,7 @@ class JsonGenerator : public TreeGenerator { std::to_string(integer_threshold), depth); } - std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) override { + std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) const override { static std::string const kQuantitiveTemplate = R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I" R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I" @@ -376,7 +377,7 @@ class JsonGenerator : public TreeGenerator { return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth); } - std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) override { + std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override { auto cond = tree[nid].SplitCond(); static std::string const kNodeTemplate = R"I( "nodeid": {nid}, "depth": {depth}, "split": {fname}, )I" @@ -385,7 +386,7 @@ class JsonGenerator : public TreeGenerator { return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth); } - std::string NodeStat(RegTree const& tree, int32_t nid) override { + std::string NodeStat(RegTree const& tree, int32_t nid) const override { static std::string kStatTemplate = R"S(, "gain": {loss_chg}, "cover": {sum_hess})S"; auto result = SuperT::Match( @@ -529,7 +530,7 @@ class GraphvizGenerator : public TreeGenerator { protected: // Only indicator is different, so we combine all different node types into this // function. - std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) override { + std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override { auto split = tree[nid].SplitIndex(); auto cond = tree[nid].SplitCond(); static std::string const kNodeTemplate = @@ -563,7 +564,7 @@ class GraphvizGenerator : public TreeGenerator { return result; }; - std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) override { + std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) const override { static std::string const kLeafTemplate = " {nid} [ label=\"leaf={leaf-value}\" {params}]\n"; auto result = SuperT::Match(kLeafTemplate, { diff --git a/tests/cpp/tree/test_tree_model.cc b/tests/cpp/tree/test_tree_model.cc index 406e9e62f..dbf2b80a2 100644 --- a/tests/cpp/tree/test_tree_model.cc +++ b/tests/cpp/tree/test_tree_model.cc @@ -151,6 +151,10 @@ TEST(Tree, DumpJson) { str = tree.DumpModel(fmap, false, "json"); ASSERT_EQ(str.find("cover"), std::string::npos); + + + auto j_tree = Json::Load({str.c_str(), str.size()}); + ASSERT_EQ(get(j_tree["children"]).size(), 2); } TEST(Tree, DumpText) { diff --git a/tests/python/test_basic_models.py b/tests/python/test_basic_models.py index c94012b2a..9b49c5360 100644 --- a/tests/python/test_basic_models.py +++ b/tests/python/test_basic_models.py @@ -325,7 +325,7 @@ class TestModels(unittest.TestCase): assert locale.getpreferredencoding(False) == loc @pytest.mark.skipif(**tm.no_json_schema()) - def test_json_schema(self): + def test_json_io_schema(self): import jsonschema model_path = 'test_json_schema.json' path = os.path.dirname( @@ -342,3 +342,35 @@ class TestModels(unittest.TestCase): jsonschema.validate(instance=json_model(model_path, parameters), schema=schema) os.remove(model_path) + + @pytest.mark.skipif(**tm.no_json_schema()) + def test_json_dump_schema(self): + import jsonschema + + def validate_model(parameters): + X = np.random.random((100, 30)) + y = np.random.randint(0, 4, size=(100,)) + + parameters['num_class'] = 4 + m = xgb.DMatrix(X, y) + + booster = xgb.train(parameters, m) + dump = booster.get_dump(dump_format='json') + + for i in range(len(dump)): + jsonschema.validate(instance=json.loads(dump[i]), + schema=schema) + + path = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + doc = os.path.join(path, 'doc', 'dump.schema') + with open(doc, 'r') as fd: + schema = json.load(fd) + + parameters = {'tree_method': 'hist', 'booster': 'gbtree', + 'objective': 'multi:softmax'} + validate_model(parameters) + + parameters = {'tree_method': 'hist', 'booster': 'dart', + 'objective': 'multi:softmax'} + validate_model(parameters)