[R] xgb.plot.tree fixes (#1939)

* [R] a few fixes and improvements to xgb.plot.tree

* [R] deprecate n_first_tree replace with trees; fix types in xgb.model.dt.tree
This commit is contained in:
Vadim Khotilovich
2017-01-06 13:09:51 -06:00
committed by Tianqi Chen
parent d23ea5ca7d
commit d7406e07f3
7 changed files with 225 additions and 116 deletions

View File

@@ -2,37 +2,65 @@
#'
#' Read a tree model text dump and plot the model.
#'
#' @param feature_names names of each feature as a \code{character} vector. Can be extracted from a sparse matrix (see example). If model dump already contains feature names, this argument should be \code{NULL}.
#' @param model generated by the \code{xgb.train} function. Avoid the creation of a dump file.
#' @param n_first_tree limit the plot to the n first trees. If \code{NULL}, all trees of the model are plotted. Performance can be low for huge models.
#' @param feature_names names of each feature as a \code{character} vector.
#' @param model produced by the \code{xgb.train} function.
#' @param trees an integer vector of tree indices that should be visualized.
#' If set to \code{NULL}, all trees of the model are included.
#' IMPORTANT: the tree index in xgboost model is zero-based
#' (e.g., use \code{trees = 0:2} for the first 3 trees in a model).
#' @param plot_width the width of the diagram in pixels.
#' @param plot_height the height of the diagram in pixels.
#' @param render a logical flag for whether the graph should be rendered (see Value).
#' @param show_node_id a logical flag for whether to include node id's in the graph.
#' @param ... currently not used.
#'
#' @return A \code{DiagrammeR} of the model.
#'
#' @details
#'
#' The content of each node is organised that way:
#'
#' \itemize{
#' \item \code{feature} value;
#' \item \code{cover}: the sum of second order gradient of training data classified to the leaf, if it is square loss, this simply corresponds to the number of instances in that branch. Deeper in the tree a node is, lower this metric will be;
#' \item \code{gain}: metric the importance of the node in the model.
#' \item Feature name.
#' \item \code{Cover}: The sum of second order gradient of training data classified to the leaf.
#' If it is square loss, this simply corresponds to the number of instances seen by a split
#' or collected by a leaf during training.
#' The deeper in the tree a node is, the lower this metric will be.
#' \item \code{Gain} (for split nodes): the information gain metric of a split
#' (corresponds to the importance of the node in the model).
#' \item \code{Value} (for leafs): the margin value that the leaf may contribute to prediction.
#' }
#' The tree root nodes also indicate the Tree index (0-based).
#'
#' The function uses \href{http://www.graphviz.org/}{GraphViz} library for that purpose.
#' The "Yes" branches are marked by the "< split_value" label.
#' The branches that also used for missing values are marked as bold
#' (as in "carrying extra capacity").
#'
#' This function uses \href{http://www.graphviz.org/}{GraphViz} as a backend of DiagrammeR.
#'
#' @return
#'
#' When \code{render = TRUE}:
#' returns a rendered graph object which is an \code{htmlwidget} of class \code{grViz}.
#' Similar to ggplot objects, it needs to be printed to see it when not running from command line.
#'
#' When \code{render = FALSE}:
#' silently returns a graph object which is of DiagrammeR's class \code{dgr_graph}.
#' This could be useful if one wants to modify some of the graph attributes
#' before rendering the graph with \code{\link[DiagrammeR]{render_graph}}.
#'
#' @examples
#' data(agaricus.train, package='xgboost')
#'
#' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max_depth = 2,
#' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max_depth = 3,
#' eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic")
#'
#' # plot all the trees
#' xgb.plot.tree(feature_names = colnames(agaricus.train$data), model = bst)
#' # plot only the first tree and include the node ID:
#' xgb.plot.tree(feature_names = colnames(agaricus.train$data), model = bst,
#' trees = 0, show_node_id = TRUE)
#'
#' @export
xgb.plot.tree <- function(feature_names = NULL, model = NULL, n_first_tree = NULL, plot_width = NULL, plot_height = NULL, ...){
xgb.plot.tree <- function(feature_names = NULL, model = NULL, trees = NULL, plot_width = NULL, plot_height = NULL,
render = TRUE, show_node_id = FALSE, ...){
check.deprecation(...)
if (class(model) != "xgb.Booster") {
stop("model: Has to be an object of class xgb.Booster model generaged by the xgb.train function.")
@@ -42,34 +70,55 @@ xgb.plot.tree <- function(feature_names = NULL, model = NULL, n_first_tree = NUL
stop("DiagrammeR package is required for xgb.plot.tree", call. = FALSE)
}
allTrees <- xgb.model.dt.tree(feature_names = feature_names, model = model, n_first_tree = n_first_tree)
allTrees[, label:= paste0(Feature, "\nCover: ", Cover, "\nGain: ", Quality)]
allTrees[, shape:= "rectangle"][Feature == "Leaf", shape:= "oval"]
allTrees[, filledcolor:= "Beige"][Feature == "Leaf", filledcolor:= "Khaki"]
# rev is used to put the first tree on top.
nodes <- DiagrammeR::create_node_df(n = length(allTrees[,ID] %>% rev),
label = allTrees[,label] %>% rev,
style = "filled",
color = "DimGray",
fillcolor= allTrees[,filledcolor] %>% rev,
shape = allTrees[,shape] %>% rev,
data = allTrees[,Feature] %>% rev,
fontname = "Helvetica"
)
edges <- DiagrammeR::create_edge_df(from = match(allTrees[Feature != "Leaf", c(ID)] %>% rep(2), allTrees[,ID] %>% rev),
to = match(allTrees[Feature != "Leaf", c(Yes, No)],allTrees[,ID] %>% rev),
label = allTrees[Feature != "Leaf", paste("<",Split)] %>% c(rep("",nrow(allTrees[Feature != "Leaf"]))),
color = "DimGray",
arrowsize = "1.5",
arrowhead = "vee",
fontname = "Helvetica",
rel = "leading_to")
dt <- xgb.model.dt.tree(feature_names = feature_names, model = model, trees = trees)
graph <- DiagrammeR::create_graph(nodes_df = nodes,
edges_df = edges)
dt[, label:= paste0(Feature, "\nCover: ", Cover, ifelse(Feature == "Leaf", "\nValue: ", "\nGain: "), Quality)]
if (show_node_id)
dt[, label := paste0(ID, ": ", label)]
dt[Node == 0, label := paste0("Tree ", Tree, "\n", label)]
dt[, shape:= "rectangle"][Feature == "Leaf", shape:= "oval"]
dt[, filledcolor:= "Beige"][Feature == "Leaf", filledcolor:= "Khaki"]
# in order to draw the first tree on top:
dt <- dt[order(-Tree)]
nodes <- DiagrammeR::create_node_df(
n = nrow(dt),
ID = dt$ID,
label = dt$label,
fillcolor = dt$filledcolor,
shape = dt$shape,
data = dt$Feature)
edges <- DiagrammeR::create_edge_df(
from = match(dt[Feature != "Leaf", c(ID)] %>% rep(2), dt$ID),
to = match(dt[Feature != "Leaf", c(Yes, No)], dt$ID),
label = dt[Feature != "Leaf", paste("<", Split)] %>%
c(rep("", nrow(dt[Feature != "Leaf"]))),
style = dt[Feature != "Leaf", ifelse(Missing == Yes, "bold", "solid")] %>%
c(dt[Feature != "Leaf", ifelse(Missing == No, "bold", "solid")]),
rel = "leading_to")
graph <- DiagrammeR::create_graph(
nodes_df = nodes,
edges_df = edges,
attr_theme = NULL
) %>%
DiagrammeR::add_global_graph_attrs(
attr_type = "graph",
attr = c("layout", "rankdir"),
value = c("dot", "LR")
) %>%
DiagrammeR::add_global_graph_attrs(
attr_type = "node",
attr = c("color", "style", "fontname"),
value = c("DimGray", "filled", "Helvetica")
) %>%
DiagrammeR::add_global_graph_attrs(
attr_type = "edge",
attr = c("color", "arrowsize", "arrowhead", "fontname"),
value = c("DimGray", "1.5", "vee", "Helvetica"))
if (!render) return(invisible(graph))
DiagrammeR::render_graph(graph, width = plot_width, height = plot_height)
}