[R] xgb.plot.tree fixes (#1939)

* [R] a few fixes and improvements to xgb.plot.tree

* [R] deprecate n_first_tree replace with trees; fix types in xgb.model.dt.tree
This commit is contained in:
Vadim Khotilovich 2017-01-06 13:09:51 -06:00 committed by Tianqi Chen
parent d23ea5ca7d
commit d7406e07f3
7 changed files with 225 additions and 116 deletions

View File

@ -304,6 +304,7 @@ depr_par_lut <- matrix(c(
'features.keep', 'features_keep', 'features.keep', 'features_keep',
'plot.height','plot_height', 'plot.height','plot_height',
'plot.width','plot_width', 'plot.width','plot_width',
'n_first_tree', 'trees',
'dummy', 'DUMMY' 'dummy', 'DUMMY'
), ncol=2, byrow = TRUE) ), ncol=2, byrow = TRUE)
colnames(depr_par_lut) <- c('old', 'new') colnames(depr_par_lut) <- c('old', 'new')

View File

@ -7,8 +7,12 @@
#' @param model object of class \code{xgb.Booster} #' @param model object of class \code{xgb.Booster}
#' @param text \code{character} vector previously generated by the \code{xgb.dump} #' @param text \code{character} vector previously generated by the \code{xgb.dump}
#' function (where parameter \code{with_stats = TRUE} should have been set). #' function (where parameter \code{with_stats = TRUE} should have been set).
#' @param n_first_tree limit the parsing to the \code{n} first trees. #' @param trees an integer vector of tree indices that should be parsed.
#' If set to \code{NULL}, all trees of the model are parsed. #' If set to \code{NULL}, all trees of the model are parsed.
#' It could be useful, e.g., in multiclass classification to get only
#' the trees of one certain class. IMPORTANT: the tree index in xgboost model
#' is zero-based (e.g., use \code{trees = 0:4} for first 5 trees).
#' @param ... currently not used.
#' #'
#' @return #' @return
#' A \code{data.table} with detailed information about model trees' nodes. #' A \code{data.table} with detailed information about model trees' nodes.
@ -16,9 +20,9 @@
#' The columns of the \code{data.table} are: #' The columns of the \code{data.table} are:
#' #'
#' \itemize{ #' \itemize{
#' \item \code{Tree}: ID of a tree in a model #' \item \code{Tree}: ID of a tree in a model (integer)
#' \item \code{Node}: ID of a node in a tree #' \item \code{Node}: integer ID of a node in a tree (integer)
#' \item \code{ID}: unique identifier of a node in a model #' \item \code{ID}: identifier of a node in a model (character)
#' \item \code{Feature}: for a branch node, it's a feature id or name (when available); #' \item \code{Feature}: for a branch node, it's a feature id or name (when available);
#' for a leaf note, it simply labels it as \code{'Leaf'} #' for a leaf note, it simply labels it as \code{'Leaf'}
#' \item \code{Split}: location of the split for a branch node (split condition is always "less than") #' \item \code{Split}: location of the split for a branch node (split condition is always "less than")
@ -47,8 +51,8 @@
#' #'
#' @export #' @export
xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL, xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
n_first_tree = NULL){ trees = NULL, ...){
check.deprecation(...)
if (!class(feature_names) %in% c("character", "NULL")) { if (!class(feature_names) %in% c("character", "NULL")) {
stop("feature_names: Has to be a vector of character\n", stop("feature_names: Has to be a vector of character\n",
" or NULL if the model dump already contains feature names.\n", " or NULL if the model dump already contains feature names.\n",
@ -61,8 +65,8 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
" (or NULL if the model was provided).") " (or NULL if the model was provided).")
} }
if (!class(n_first_tree) %in% c("numeric", "NULL") | length(n_first_tree) > 1) { if (!class(trees) %in% c("integer", "numeric", "NULL")) {
stop("n_first_tree: Has to be a numeric vector of size 1.") stop("trees: Has to be a vector of integers.")
} }
if (is.null(text)){ if (is.null(text)){
@ -84,10 +88,14 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
td[position, Tree := 1L] td[position, Tree := 1L]
td[, Tree := cumsum(ifelse(is.na(Tree), 0L, Tree)) - 1L] td[, Tree := cumsum(ifelse(is.na(Tree), 0L, Tree)) - 1L]
n_first_tree <- min(max(td$Tree), n_first_tree) if (is.null(trees)) {
td <- td[Tree <= n_first_tree & !grepl('^booster', t)] trees <- 0:max(td$Tree)
} else {
trees <- trees[trees >= 0 & trees <= max(td$Tree)]
}
td <- td[Tree %in% trees & !grepl('^booster', t)]
td[, Node := stri_match_first_regex(t, "(\\d+):")[,2] %>% as.numeric ] td[, Node := stri_match_first_regex(t, "(\\d+):")[,2] %>% as.integer ]
td[, ID := add.tree.id(Node, Tree)] td[, ID := add.tree.id(Node, Tree)]
td[, isLeaf := !is.na(stri_match_first_regex(t, "leaf"))] td[, isLeaf := !is.na(stri_match_first_regex(t, "leaf"))]
@ -112,7 +120,7 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
}] }]
# convert some columns to numeric # convert some columns to numeric
numeric_cols <- c("Quality", "Cover") numeric_cols <- c("Split", "Quality", "Cover")
td[, (numeric_cols) := lapply(.SD, as.numeric), .SDcols=numeric_cols] td[, (numeric_cols) := lapply(.SD, as.numeric), .SDcols=numeric_cols]
td[, t := NULL] td[, t := NULL]

View File

@ -2,8 +2,8 @@
#' #'
#' Visualization of the ensemble of trees as a single collective unit. #' Visualization of the ensemble of trees as a single collective unit.
#' #'
#' @param model dump generated by the \code{xgb.train} function. #' @param model produced by the \code{xgb.train} function.
#' @param feature_names names of each feature as a \code{character} vector. Can be extracted from a sparse matrix (see example). If model dump already contains feature names, this argument should be \code{NULL}. #' @param feature_names names of each feature as a \code{character} vector.
#' @param features_keep number of features to keep in each position of the multi trees. #' @param features_keep number of features to keep in each position of the multi trees.
#' @param plot_width width in pixels of the graph to produce #' @param plot_width width in pixels of the graph to produce
#' @param plot_height height in pixels of the graph to produce #' @param plot_height height in pixels of the graph to produce
@ -13,21 +13,19 @@
#' #'
#' @details #' @details
#' #'
#' This function tries to capture the complexity of gradient boosted tree ensemble #' This function tries to capture the complexity of a gradient boosted tree model
#' in a cohesive way. #' in a cohesive way by compressing an ensemble of trees into a single tree-graph representation.
#' The goal is to improve the interpretability of a model generally seen as black box.
#' #'
#' The goal is to improve the interpretability of the model generally seen as black box. #' Note: this function is applicable to tree booster-based models only.
#' The function is dedicated to boosting applied to decision trees only.
#'
#' The purpose is to move from an ensemble of trees to a single tree only.
#' #'
#' It takes advantage of the fact that the shape of a binary tree is only defined by #' It takes advantage of the fact that the shape of a binary tree is only defined by
#' its deepness (therefore in a boosting model, all trees have the same shape). #' its depth (therefore, in a boosting model, all trees have similar shape).
#' #'
#' Moreover, the trees tend to reuse the same features. #' Moreover, the trees tend to reuse the same features.
#' #'
#' The function will project each tree on one, and keep for each position the #' The function projects each tree onto one, and keeps for each position the
#' \code{features_keep} first features (based on Gain per feature measure). #' \code{features_keep} first features (based on the Gain per feature measure).
#' #'
#' This function is inspired by this blog post: #' This function is inspired by this blog post:
#' \url{https://wellecks.wordpress.com/2015/02/21/peering-into-the-black-box-visualizing-lambdamart/} #' \url{https://wellecks.wordpress.com/2015/02/21/peering-into-the-black-box-visualizing-lambdamart/}
@ -70,39 +68,61 @@ xgb.plot.multi.trees <- function(model, feature_names = NULL, features_keep = 5,
tree.matrix[!is.na(No),No:= paste0(abs.node.position, "_1")] tree.matrix[!is.na(No),No:= paste0(abs.node.position, "_1")]
remove.tree <- . %>% stri_replace_first_regex(pattern = "^\\d+-", replacement = "") remove.tree <- . %>% stri_replace_first_regex(pattern = "^\\d+-", replacement = "")
tree.matrix[,`:=`(abs.node.position=remove.tree(abs.node.position), Yes=remove.tree(Yes), No=remove.tree(No))] tree.matrix[,`:=`(abs.node.position = remove.tree(abs.node.position),
Yes = remove.tree(Yes),
No = remove.tree(No))]
nodes.dt <- tree.matrix[,.(Quality = sum(Quality)),by = .(abs.node.position, Feature)][,.(Text =paste0(Feature[1:min(length(Feature), features_keep)], " (", Quality[1:min(length(Quality), features_keep)], ")") %>% paste0(collapse = "\n")), by=abs.node.position] nodes.dt <- tree.matrix[
edges.dt <- tree.matrix[Feature != "Leaf",.(abs.node.position, Yes)] %>% list(tree.matrix[Feature != "Leaf",.(abs.node.position, No)]) %>% rbindlist() %>% setnames(c("From", "To")) %>% .[,.N,.(From, To)] %>% .[,N:=NULL] , .(Quality = sum(Quality))
, by = .(abs.node.position, Feature)
][, .(Text = paste0(Feature[1:min(length(Feature), features_keep)],
" (",
format(Quality[1:min(length(Quality), features_keep)], digits=5),
")") %>%
paste0(collapse = "\n"))
, by = abs.node.position]
nodes <- DiagrammeR::create_node_df(n = nrow(nodes.dt), edges.dt <- tree.matrix[Feature != "Leaf", .(abs.node.position, Yes)] %>%
label = nodes.dt[,Text], list(tree.matrix[Feature != "Leaf",.(abs.node.position, No)]) %>%
style = "filled", rbindlist() %>%
color = "DimGray", setnames(c("From", "To")) %>%
fillcolor= "Beige", .[, .N, .(From, To)] %>%
shape = "oval", .[, N:=NULL]
fontname = "Helvetica"
nodes <- DiagrammeR::create_node_df(
n = nrow(nodes.dt),
label = nodes.dt[,Text]
) )
edges <- DiagrammeR::create_edge_df(from = match(edges.dt[,From], nodes.dt[,abs.node.position]), edges <- DiagrammeR::create_edge_df(
to = match(edges.dt[,To], nodes.dt[,abs.node.position]), from = match(edges.dt[,From], nodes.dt[,abs.node.position]),
color = "DimGray", to = match(edges.dt[,To], nodes.dt[,abs.node.position]),
arrowsize = "1.5", rel = "leading_to")
arrowhead = "vee",
fontname = "Helvetica",
rel = "leading_to")
graph <- DiagrammeR::create_graph(nodes_df = nodes, graph <- DiagrammeR::create_graph(
edges_df = edges) nodes_df = nodes,
edges_df = edges,
attr_theme = NULL
) %>%
DiagrammeR::add_global_graph_attrs(
attr_type = "graph",
attr = c("layout", "rankdir"),
value = c("dot", "LR")
) %>%
DiagrammeR::add_global_graph_attrs(
attr_type = "node",
attr = c("color", "fillcolor", "style", "shape", "fontname"),
value = c("DimGray", "beige", "filled", "rectangle", "Helvetica")
) %>%
DiagrammeR::add_global_graph_attrs(
attr_type = "edge",
attr = c("color", "arrowsize", "arrowhead", "fontname"),
value = c("DimGray", "1.5", "vee", "Helvetica"))
DiagrammeR::render_graph(graph, width = plot_width, height = plot_height) DiagrammeR::render_graph(graph, width = plot_width, height = plot_height)
} }
globalVariables( globalVariables(c(".N", "N", "From", "To", "Text", "Feature", "no.nodes.abs.pos",
c( "ID", "Yes", "No", "Tree", "yes.nodes.abs.pos", "abs.node.position"))
".N", "N", "From", "To", "Text", "Feature", "no.nodes.abs.pos", "ID", "Yes", "No", "Tree", "yes.nodes.abs.pos", "abs.node.position"
)
)

View File

@ -2,37 +2,65 @@
#' #'
#' Read a tree model text dump and plot the model. #' Read a tree model text dump and plot the model.
#' #'
#' @param feature_names names of each feature as a \code{character} vector. Can be extracted from a sparse matrix (see example). If model dump already contains feature names, this argument should be \code{NULL}. #' @param feature_names names of each feature as a \code{character} vector.
#' @param model generated by the \code{xgb.train} function. Avoid the creation of a dump file. #' @param model produced by the \code{xgb.train} function.
#' @param n_first_tree limit the plot to the n first trees. If \code{NULL}, all trees of the model are plotted. Performance can be low for huge models. #' @param trees an integer vector of tree indices that should be visualized.
#' If set to \code{NULL}, all trees of the model are included.
#' IMPORTANT: the tree index in xgboost model is zero-based
#' (e.g., use \code{trees = 0:2} for the first 3 trees in a model).
#' @param plot_width the width of the diagram in pixels. #' @param plot_width the width of the diagram in pixels.
#' @param plot_height the height of the diagram in pixels. #' @param plot_height the height of the diagram in pixels.
#' @param render a logical flag for whether the graph should be rendered (see Value).
#' @param show_node_id a logical flag for whether to include node id's in the graph.
#' @param ... currently not used. #' @param ... currently not used.
#' #'
#' @return A \code{DiagrammeR} of the model.
#'
#' @details #' @details
#' #'
#' The content of each node is organised that way: #' The content of each node is organised that way:
#' #'
#' \itemize{ #' \itemize{
#' \item \code{feature} value; #' \item Feature name.
#' \item \code{cover}: the sum of second order gradient of training data classified to the leaf, if it is square loss, this simply corresponds to the number of instances in that branch. Deeper in the tree a node is, lower this metric will be; #' \item \code{Cover}: The sum of second order gradient of training data classified to the leaf.
#' \item \code{gain}: metric the importance of the node in the model. #' If it is square loss, this simply corresponds to the number of instances seen by a split
#' or collected by a leaf during training.
#' The deeper in the tree a node is, the lower this metric will be.
#' \item \code{Gain} (for split nodes): the information gain metric of a split
#' (corresponds to the importance of the node in the model).
#' \item \code{Value} (for leafs): the margin value that the leaf may contribute to prediction.
#' } #' }
#' The tree root nodes also indicate the Tree index (0-based).
#' #'
#' The function uses \href{http://www.graphviz.org/}{GraphViz} library for that purpose. #' The "Yes" branches are marked by the "< split_value" label.
#' The branches that also used for missing values are marked as bold
#' (as in "carrying extra capacity").
#'
#' This function uses \href{http://www.graphviz.org/}{GraphViz} as a backend of DiagrammeR.
#'
#' @return
#'
#' When \code{render = TRUE}:
#' returns a rendered graph object which is an \code{htmlwidget} of class \code{grViz}.
#' Similar to ggplot objects, it needs to be printed to see it when not running from command line.
#'
#' When \code{render = FALSE}:
#' silently returns a graph object which is of DiagrammeR's class \code{dgr_graph}.
#' This could be useful if one wants to modify some of the graph attributes
#' before rendering the graph with \code{\link[DiagrammeR]{render_graph}}.
#' #'
#' @examples #' @examples
#' data(agaricus.train, package='xgboost') #' data(agaricus.train, package='xgboost')
#' #'
#' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max_depth = 2, #' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max_depth = 3,
#' eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic") #' eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic")
#' #' # plot all the trees
#' xgb.plot.tree(feature_names = colnames(agaricus.train$data), model = bst) #' xgb.plot.tree(feature_names = colnames(agaricus.train$data), model = bst)
#' # plot only the first tree and include the node ID:
#' xgb.plot.tree(feature_names = colnames(agaricus.train$data), model = bst,
#' trees = 0, show_node_id = TRUE)
#' #'
#' @export #' @export
xgb.plot.tree <- function(feature_names = NULL, model = NULL, n_first_tree = NULL, plot_width = NULL, plot_height = NULL, ...){ xgb.plot.tree <- function(feature_names = NULL, model = NULL, trees = NULL, plot_width = NULL, plot_height = NULL,
render = TRUE, show_node_id = FALSE, ...){
check.deprecation(...) check.deprecation(...)
if (class(model) != "xgb.Booster") { if (class(model) != "xgb.Booster") {
stop("model: Has to be an object of class xgb.Booster model generaged by the xgb.train function.") stop("model: Has to be an object of class xgb.Booster model generaged by the xgb.train function.")
@ -42,34 +70,55 @@ xgb.plot.tree <- function(feature_names = NULL, model = NULL, n_first_tree = NUL
stop("DiagrammeR package is required for xgb.plot.tree", call. = FALSE) stop("DiagrammeR package is required for xgb.plot.tree", call. = FALSE)
} }
allTrees <- xgb.model.dt.tree(feature_names = feature_names, model = model, n_first_tree = n_first_tree) dt <- xgb.model.dt.tree(feature_names = feature_names, model = model, trees = trees)
allTrees[, label:= paste0(Feature, "\nCover: ", Cover, "\nGain: ", Quality)] dt[, label:= paste0(Feature, "\nCover: ", Cover, ifelse(Feature == "Leaf", "\nValue: ", "\nGain: "), Quality)]
allTrees[, shape:= "rectangle"][Feature == "Leaf", shape:= "oval"] if (show_node_id)
allTrees[, filledcolor:= "Beige"][Feature == "Leaf", filledcolor:= "Khaki"] dt[, label := paste0(ID, ": ", label)]
dt[Node == 0, label := paste0("Tree ", Tree, "\n", label)]
dt[, shape:= "rectangle"][Feature == "Leaf", shape:= "oval"]
dt[, filledcolor:= "Beige"][Feature == "Leaf", filledcolor:= "Khaki"]
# in order to draw the first tree on top:
dt <- dt[order(-Tree)]
# rev is used to put the first tree on top. nodes <- DiagrammeR::create_node_df(
nodes <- DiagrammeR::create_node_df(n = length(allTrees[,ID] %>% rev), n = nrow(dt),
label = allTrees[,label] %>% rev, ID = dt$ID,
style = "filled", label = dt$label,
color = "DimGray", fillcolor = dt$filledcolor,
fillcolor= allTrees[,filledcolor] %>% rev, shape = dt$shape,
shape = allTrees[,shape] %>% rev, data = dt$Feature)
data = allTrees[,Feature] %>% rev,
fontname = "Helvetica"
)
edges <- DiagrammeR::create_edge_df(from = match(allTrees[Feature != "Leaf", c(ID)] %>% rep(2), allTrees[,ID] %>% rev), edges <- DiagrammeR::create_edge_df(
to = match(allTrees[Feature != "Leaf", c(Yes, No)],allTrees[,ID] %>% rev), from = match(dt[Feature != "Leaf", c(ID)] %>% rep(2), dt$ID),
label = allTrees[Feature != "Leaf", paste("<",Split)] %>% c(rep("",nrow(allTrees[Feature != "Leaf"]))), to = match(dt[Feature != "Leaf", c(Yes, No)], dt$ID),
color = "DimGray", label = dt[Feature != "Leaf", paste("<", Split)] %>%
arrowsize = "1.5", c(rep("", nrow(dt[Feature != "Leaf"]))),
arrowhead = "vee", style = dt[Feature != "Leaf", ifelse(Missing == Yes, "bold", "solid")] %>%
fontname = "Helvetica", c(dt[Feature != "Leaf", ifelse(Missing == No, "bold", "solid")]),
rel = "leading_to") rel = "leading_to")
graph <- DiagrammeR::create_graph(nodes_df = nodes, graph <- DiagrammeR::create_graph(
edges_df = edges) nodes_df = nodes,
edges_df = edges,
attr_theme = NULL
) %>%
DiagrammeR::add_global_graph_attrs(
attr_type = "graph",
attr = c("layout", "rankdir"),
value = c("dot", "LR")
) %>%
DiagrammeR::add_global_graph_attrs(
attr_type = "node",
attr = c("color", "style", "fontname"),
value = c("DimGray", "filled", "Helvetica")
) %>%
DiagrammeR::add_global_graph_attrs(
attr_type = "edge",
attr = c("color", "arrowsize", "arrowhead", "fontname"),
value = c("DimGray", "1.5", "vee", "Helvetica"))
if (!render) return(invisible(graph))
DiagrammeR::render_graph(graph, width = plot_width, height = plot_height) DiagrammeR::render_graph(graph, width = plot_width, height = plot_height)
} }

View File

@ -5,7 +5,7 @@
\title{Parse a boosted tree model text dump} \title{Parse a boosted tree model text dump}
\usage{ \usage{
xgb.model.dt.tree(feature_names = NULL, model = NULL, text = NULL, xgb.model.dt.tree(feature_names = NULL, model = NULL, text = NULL,
n_first_tree = NULL) trees = NULL, ...)
} }
\arguments{ \arguments{
\item{feature_names}{character vector of feature names. If the model already \item{feature_names}{character vector of feature names. If the model already
@ -16,8 +16,13 @@ contains feature names, this argument should be \code{NULL} (default value)}
\item{text}{\code{character} vector previously generated by the \code{xgb.dump} \item{text}{\code{character} vector previously generated by the \code{xgb.dump}
function (where parameter \code{with_stats = TRUE} should have been set).} function (where parameter \code{with_stats = TRUE} should have been set).}
\item{n_first_tree}{limit the parsing to the \code{n} first trees. \item{trees}{an integer vector of tree indices that should be parsed.
If set to \code{NULL}, all trees of the model are parsed.} If set to \code{NULL}, all trees of the model are parsed.
It could be useful, e.g., in multiclass classification to get only
the trees of one certain class. IMPORTANT: the tree index in xgboost model
is zero-based (e.g., use \code{trees = 0:4} for first 5 trees).}
\item{...}{currently not used.}
} }
\value{ \value{
A \code{data.table} with detailed information about model trees' nodes. A \code{data.table} with detailed information about model trees' nodes.
@ -25,9 +30,9 @@ A \code{data.table} with detailed information about model trees' nodes.
The columns of the \code{data.table} are: The columns of the \code{data.table} are:
\itemize{ \itemize{
\item \code{Tree}: ID of a tree in a model \item \code{Tree}: ID of a tree in a model (integer)
\item \code{Node}: ID of a node in a tree \item \code{Node}: integer ID of a node in a tree (integer)
\item \code{ID}: unique identifier of a node in a model \item \code{ID}: identifier of a node in a model (character)
\item \code{Feature}: for a branch node, it's a feature id or name (when available); \item \code{Feature}: for a branch node, it's a feature id or name (when available);
for a leaf note, it simply labels it as \code{'Leaf'} for a leaf note, it simply labels it as \code{'Leaf'}
\item \code{Split}: location of the split for a branch node (split condition is always "less than") \item \code{Split}: location of the split for a branch node (split condition is always "less than")

View File

@ -8,9 +8,9 @@ xgb.plot.multi.trees(model, feature_names = NULL, features_keep = 5,
plot_width = NULL, plot_height = NULL, ...) plot_width = NULL, plot_height = NULL, ...)
} }
\arguments{ \arguments{
\item{model}{dump generated by the \code{xgb.train} function.} \item{model}{produced by the \code{xgb.train} function.}
\item{feature_names}{names of each feature as a \code{character} vector. Can be extracted from a sparse matrix (see example). If model dump already contains feature names, this argument should be \code{NULL}.} \item{feature_names}{names of each feature as a \code{character} vector.}
\item{features_keep}{number of features to keep in each position of the multi trees.} \item{features_keep}{number of features to keep in each position of the multi trees.}
@ -27,21 +27,19 @@ Two graphs showing the distribution of the model deepness.
Visualization of the ensemble of trees as a single collective unit. Visualization of the ensemble of trees as a single collective unit.
} }
\details{ \details{
This function tries to capture the complexity of gradient boosted tree ensemble This function tries to capture the complexity of a gradient boosted tree model
in a cohesive way. in a cohesive way by compressing an ensemble of trees into a single tree-graph representation.
The goal is to improve the interpretability of a model generally seen as black box.
The goal is to improve the interpretability of the model generally seen as black box. Note: this function is applicable to tree booster-based models only.
The function is dedicated to boosting applied to decision trees only.
The purpose is to move from an ensemble of trees to a single tree only.
It takes advantage of the fact that the shape of a binary tree is only defined by It takes advantage of the fact that the shape of a binary tree is only defined by
its deepness (therefore in a boosting model, all trees have the same shape). its depth (therefore, in a boosting model, all trees have similar shape).
Moreover, the trees tend to reuse the same features. Moreover, the trees tend to reuse the same features.
The function will project each tree on one, and keep for each position the The function projects each tree onto one, and keeps for each position the
\code{features_keep} first features (based on Gain per feature measure). \code{features_keep} first features (based on the Gain per feature measure).
This function is inspired by this blog post: This function is inspired by this blog post:
\url{https://wellecks.wordpress.com/2015/02/21/peering-into-the-black-box-visualizing-lambdamart/} \url{https://wellecks.wordpress.com/2015/02/21/peering-into-the-black-box-visualizing-lambdamart/}

View File

@ -4,24 +4,39 @@
\alias{xgb.plot.tree} \alias{xgb.plot.tree}
\title{Plot a boosted tree model} \title{Plot a boosted tree model}
\usage{ \usage{
xgb.plot.tree(feature_names = NULL, model = NULL, n_first_tree = NULL, xgb.plot.tree(feature_names = NULL, model = NULL, trees = NULL,
plot_width = NULL, plot_height = NULL, ...) plot_width = NULL, plot_height = NULL, render = TRUE,
show_node_id = FALSE, ...)
} }
\arguments{ \arguments{
\item{feature_names}{names of each feature as a \code{character} vector. Can be extracted from a sparse matrix (see example). If model dump already contains feature names, this argument should be \code{NULL}.} \item{feature_names}{names of each feature as a \code{character} vector.}
\item{model}{generated by the \code{xgb.train} function. Avoid the creation of a dump file.} \item{model}{produced by the \code{xgb.train} function.}
\item{n_first_tree}{limit the plot to the n first trees. If \code{NULL}, all trees of the model are plotted. Performance can be low for huge models.} \item{trees}{an integer vector of tree indices that should be visualized.
If set to \code{NULL}, all trees of the model are included.
IMPORTANT: the tree index in xgboost model is zero-based
(e.g., use \code{trees = 0:2} for the first 3 trees in a model).}
\item{plot_width}{the width of the diagram in pixels.} \item{plot_width}{the width of the diagram in pixels.}
\item{plot_height}{the height of the diagram in pixels.} \item{plot_height}{the height of the diagram in pixels.}
\item{render}{a logical flag for whether the graph should be rendered (see Value).}
\item{show_node_id}{a logical flag for whether to include node id's in the graph.}
\item{...}{currently not used.} \item{...}{currently not used.}
} }
\value{ \value{
A \code{DiagrammeR} of the model. When \code{render = TRUE}:
returns a rendered graph object which is an \code{htmlwidget} of class \code{grViz}.
Similar to ggplot objects, it needs to be printed to see it when not running from command line.
When \code{render = FALSE}:
silently returns a graph object which is of DiagrammeR's class \code{dgr_graph}.
This could be useful if one wants to modify some of the graph attributes
before rendering the graph with \code{\link[DiagrammeR]{render_graph}}.
} }
\description{ \description{
Read a tree model text dump and plot the model. Read a tree model text dump and plot the model.
@ -30,20 +45,33 @@ Read a tree model text dump and plot the model.
The content of each node is organised that way: The content of each node is organised that way:
\itemize{ \itemize{
\item \code{feature} value; \item Feature name.
\item \code{cover}: the sum of second order gradient of training data classified to the leaf, if it is square loss, this simply corresponds to the number of instances in that branch. Deeper in the tree a node is, lower this metric will be; \item \code{Cover}: The sum of second order gradient of training data classified to the leaf.
\item \code{gain}: metric the importance of the node in the model. If it is square loss, this simply corresponds to the number of instances seen by a split
or collected by a leaf during training.
The deeper in the tree a node is, the lower this metric will be.
\item \code{Gain} (for split nodes): the information gain metric of a split
(corresponds to the importance of the node in the model).
\item \code{Value} (for leafs): the margin value that the leaf may contribute to prediction.
} }
The tree root nodes also indicate the Tree index (0-based).
The function uses \href{http://www.graphviz.org/}{GraphViz} library for that purpose. The "Yes" branches are marked by the "< split_value" label.
The branches that also used for missing values are marked as bold
(as in "carrying extra capacity").
This function uses \href{http://www.graphviz.org/}{GraphViz} as a backend of DiagrammeR.
} }
\examples{ \examples{
data(agaricus.train, package='xgboost') data(agaricus.train, package='xgboost')
bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max_depth = 2, bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max_depth = 3,
eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic") eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic")
# plot all the trees
xgb.plot.tree(feature_names = colnames(agaricus.train$data), model = bst) xgb.plot.tree(feature_names = colnames(agaricus.train$data), model = bst)
# plot only the first tree and include the node ID:
xgb.plot.tree(feature_names = colnames(agaricus.train$data), model = bst,
trees = 0, show_node_id = TRUE)
} }