85 lines
3.9 KiB
R
85 lines
3.9 KiB
R
#' Plot a boosted tree model
#'
#' Convert the trees of an \code{xgb.Booster} model into a table (via
#' \code{xgb.model.dt.tree}) and plot them as a diagram.
#'
#' @importFrom data.table data.table
#' @importFrom data.table :=
#' @importFrom magrittr %>%
#' @param feature_names names of each feature as a \code{character} vector. Can be extracted from a sparse matrix (see example). If the model dump already contains feature names, this argument should be \code{NULL}.
#' @param model an object of class \code{xgb.Booster}, as generated by the \code{xgb.train} function. The trees are read directly from the model, so no dump file needs to be created.
#' @param n_first_tree limit the plot to the first n trees. If \code{NULL}, all trees of the model are plotted. Plotting can be slow for huge models.
#' @param plot.width the width of the diagram in pixels (passed to \code{DiagrammeR::render_graph}).
#' @param plot.height the height of the diagram in pixels (passed to \code{DiagrammeR::render_graph}).
#'
#' @return A \code{DiagrammeR} rendering of the model, i.e. the value returned by \code{DiagrammeR::render_graph}.
#'
#' @details
#'
#' The content of each node is organised as follows:
#'
#' \itemize{
#'  \item \code{feature}: the feature used to split at this node, or \code{Leaf} for a leaf;
#'  \item \code{cover}: the sum of second order gradients of the training data classified to the node. For square loss, this simply corresponds to the number of instances in that branch. The deeper a node is in the tree, the lower this metric will be;
#'  \item \code{gain}: a metric of the importance of the node in the model.
#' }
#'
#' Split nodes are drawn as beige rectangles and leaves as khaki ovals. Each split node has two outgoing edges: the one labelled with the split condition (\code{< value}) leads to its "yes" child, and the unlabelled one leads to its "no" child. Trees are laid out from left to right.
#'
#' The function uses the \href{http://www.graphviz.org/}{GraphViz} library, through the \code{DiagrammeR} package, to draw the diagram.
#'
#' @examples
#' data(agaricus.train, package='xgboost')
#'
#' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max.depth = 2,
#'                eta = 1, nthread = 2, nround = 2, objective = "binary:logistic")
#'
#' # agaricus.train$data@@Dimnames[[2]] represents the column names of the sparse matrix.
#' xgb.plot.tree(feature_names = agaricus.train$data@@Dimnames[[2]], model = bst)
#'
#' @export
xgb.plot.tree <- function(feature_names = NULL, model = NULL, n_first_tree = NULL, plot.width = NULL, plot.height = NULL){

  # inherits() instead of class(model) != ...: an object can carry several
  # classes, and since R 4.2 a length > 1 condition in `if` is an error.
  if (!inherits(model, "xgb.Booster")) {
    stop("model: Has to be an object of class xgb.Booster model generated by the xgb.train function.", call. = FALSE)
  }

  # DiagrammeR is an optional dependency, only needed for plotting.
  if (!requireNamespace("DiagrammeR", quietly = TRUE)) {
    stop("DiagrammeR package is required for xgb.plot.tree", call. = FALSE)
  }

  allTrees <- xgb.model.dt.tree(feature_names = feature_names, model = model, n_first_tree = n_first_tree)

  # Per-node display attributes, added by reference (data.table `:=`).
  # Split nodes are beige rectangles; leaves are khaki ovals.
  allTrees[, label := paste0(Feature, "\nCover: ", Cover, "\nGain: ", Quality)]
  allTrees[, shape := "rectangle"][Feature == "Leaf", shape := "oval"]
  allTrees[, filledcolor := "Beige"][Feature == "Leaf", filledcolor := "Khaki"]

  # rev is used to put the first tree on top.
  nodes <- DiagrammeR::create_nodes(nodes = allTrees[, ID] %>% rev,
                                    label = allTrees[, label] %>% rev,
                                    style = "filled",
                                    color = "DimGray",
                                    fillcolor = allTrees[, filledcolor] %>% rev,
                                    shape = allTrees[, shape] %>% rev,
                                    data = allTrees[, Feature] %>% rev,
                                    fontname = "Helvetica"
                                    )

  # Only split (non-leaf) nodes have outgoing edges; subset once and reuse.
  branches <- allTrees[Feature != "Leaf"]
  nBranches <- nrow(branches)

  # Each split node yields two edges. All "Yes" edges come first, labelled with
  # the split condition, followed by all "No" edges, which are unlabelled.
  # `from` repeats the IDs so that it lines up with c(Yes, No).
  edges <- DiagrammeR::create_edges(from = rep(branches[, ID], 2),
                                    to = branches[, c(Yes, No)],
                                    label = c(paste("<", branches[, Split]), rep("", nBranches)),
                                    color = "DimGray",
                                    arrowsize = "1.5",
                                    arrowhead = "vee",
                                    fontname = "Helvetica",
                                    rel = "leading_to")

  graph <- DiagrammeR::create_graph(nodes_df = nodes,
                                    edges_df = edges,
                                    graph_attrs = "rankdir = LR")

  DiagrammeR::render_graph(graph, width = plot.width, height = plot.height)
}
|
|
# Names referenced through non-standard evaluation (mostly data.table column
# names used inside `[`, plus `.`) are never bound as R variables in this
# package. Declaring them here stops R CMD check from reporting them as
# undefined global variables.
globalVariables(c(
  "Feature", "ID", "Cover", "Quality", "Split", "Yes", "No",
  ".",
  "shape", "filledcolor", "label"
))