diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index 35c052136..3b7cb9daa 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -552,20 +552,23 @@ class GraphvizGenerator : public TreeGenerator { {"{params}", param_.condition_node_params}}); static std::string const kEdgeTemplate = - " {nid} -> {child} [label=\"{is_missing}\" color=\"{color}\"]\n"; + " {nid} -> {child} [label=\"{branch}\" color=\"{color}\"]\n"; auto MatchFn = SuperT::Match; // mingw failed to capture protected fn. auto BuildEdge = - [&tree, nid, MatchFn, this](int32_t child) { + [&tree, nid, MatchFn, this](int32_t child, bool left) { + // Is this the default child for missing value? bool is_missing = tree[nid].DefaultChild() == child; + std::string branch = std::string {left ? "yes" : "no"} + + std::string {is_missing ? ", missing" : ""}; std::string buffer = MatchFn(kEdgeTemplate, { {"{nid}", std::to_string(nid)}, {"{child}", std::to_string(child)}, {"{color}", is_missing ? param_.yes_color : param_.no_color}, - {"{is_missing}", is_missing ? "yes, missing": "no"}}); + {"{branch}", branch}}); return buffer; }; - result += BuildEdge(tree[nid].LeftChild()); - result += BuildEdge(tree[nid].RightChild()); + result += BuildEdge(tree[nid].LeftChild(), true); + result += BuildEdge(tree[nid].RightChild(), false); return result; }; diff --git a/tests/cpp/tree/test_tree_model.cc b/tests/cpp/tree/test_tree_model.cc index 83be20762..ac87b25bc 100644 --- a/tests/cpp/tree/test_tree_model.cc +++ b/tests/cpp/tree/test_tree_model.cc @@ -343,6 +343,11 @@ TEST(Tree, DumpDot) { str = tree.DumpModel(fmap, true, R"(dot:{"graph_attrs": {"bgcolor": "#FFFF00"}})"); ASSERT_NE(str.find(R"(graph [ bgcolor="#FFFF00" ])"), std::string::npos); + + // Default left for root. + ASSERT_NE(str.find(R"(0 -> 1 [label="yes, missing")"), std::string::npos); + // Default right for node 1 + ASSERT_NE(str.find(R"(1 -> 4 [label="no, missing")"), std::string::npos); } TEST(Tree, JsonIO) {