#' Project all trees on one tree and plot it
#'
#' Visualization of the ensemble of trees as a single collective unit.
#'
#' @importFrom data.table data.table
#' @importFrom data.table rbindlist
#' @importFrom data.table setnames
#' @importFrom data.table :=
#' @importFrom magrittr %>%
#' @importFrom stringr str_detect
#' @importFrom stringr str_extract
#' @importFrom stringr str_replace
#'
#' @param model model generated by the \code{xgb.train} function. It is passed to
#'   \code{xgb.model.dt.tree} to obtain the table of tree nodes.
#' @param feature_names names of each feature as a \code{character} vector. Can be
#'   extracted from a sparse matrix (see example). If the model dump already contains
#'   feature names, this argument should be \code{NULL}.
#' @param features.keep number of features to keep at each position of the projected
#'   tree (a single number, at least 1). At each position, features are ranked by the
#'   sum of their \code{Quality} (Gain) values and the highest ones are kept.
#' @param plot.width width in pixels of the graph to produce
#' @param plot.height height in pixels of the graph to produce
#'
#' @return The graph rendered by \code{DiagrammeR::render_graph}.
#'
#' @details
#'
#' This function tries to capture the complexity of a gradient boosted tree ensemble
#' in a cohesive way.
#'
#' The goal is to improve the interpretability of a model that is generally seen as
#' a black box. The function is dedicated to boosting applied to decision trees only.
#'
#' The purpose is to move from an ensemble of trees to a single tree.
#'
#' Every node of a binary tree can be identified by its path from the root (the
#' sequence of Yes/No branches taken to reach it). Nodes that share the same path in
#' different trees are merged into a single position of the projected tree.
#'
#' Moreover, the trees tend to reuse the same features.
#'
#' The function projects every tree onto a single one and keeps, for each position,
#' the \code{features.keep} features with the highest summed Gain.
#'
#' This function is inspired by this blog post:
#' \url{https://wellecks.wordpress.com/2015/02/21/peering-into-the-black-box-visualizing-lambdamart/}
#'
#' @examples
#' data(agaricus.train, package='xgboost')
#'
#' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max.depth = 15,
#'                eta = 1, nthread = 2, nround = 30, objective = "binary:logistic",
#'                min_child_weight = 50)
#'
#' p <- xgb.plot.multi.trees(model = bst, feature_names = agaricus.train$data@Dimnames[[2]], features.keep = 3)
#' print(p)
#'
#' @export
xgb.plot.multi.trees <- function(model, feature_names = NULL, features.keep = 5,
                                 plot.width = NULL, plot.height = NULL) {
  if (!is.numeric(features.keep) || length(features.keep) != 1 ||
      is.na(features.keep) || features.keep < 1) {
    stop("features.keep must be a single number greater than or equal to 1.", call. = FALSE)
  }

  tree.matrix <- xgb.model.dt.tree(feature_names = feature_names, model = model)

  # A node position is the root ID "<tree>-0" followed by one "_0" (Yes branch)
  # or "_1" (No branch) suffix per level below the root.
  tree.matrix[, abs.node.position := NA_character_]
  tree.matrix[str_detect(ID, "^\\d+-0$"), abs.node.position := ID]
  precedent.nodes <- tree.matrix[!is.na(abs.node.position), abs.node.position]

  # Breadth-first walk: each pass assigns positions to the children of the
  # nodes positioned in the previous pass.
  while (tree.matrix[, anyNA(abs.node.position)]) {
    if (length(precedent.nodes) == 0) {
      # Without this guard, rows that can never be reached would loop forever.
      stop("Could not assign a position to every node of the model dump.", call. = FALSE)
    }
    yes.row.nodes <- tree.matrix[abs.node.position %in% precedent.nodes & !is.na(Yes)]
    no.row.nodes <- tree.matrix[abs.node.position %in% precedent.nodes & !is.na(No)]
    yes.nodes.abs.pos <- paste0(yes.row.nodes[, abs.node.position], "_0")
    no.nodes.abs.pos <- paste0(no.row.nodes[, abs.node.position], "_1")

    # match() pairs every child row with its own parent's path explicitly,
    # instead of assuming children appear in the table in parent order.
    yes.idx <- match(yes.row.nodes[, Yes], tree.matrix[, ID])
    no.idx <- match(no.row.nodes[, No], tree.matrix[, ID])
    yes.found <- !is.na(yes.idx)
    no.found <- !is.na(no.idx)
    tree.matrix[yes.idx[yes.found], abs.node.position := yes.nodes.abs.pos[yes.found]]
    tree.matrix[no.idx[no.found], abs.node.position := no.nodes.abs.pos[no.found]]

    precedent.nodes <- c(yes.nodes.abs.pos, no.nodes.abs.pos)
  }

  # Rewrite the child references as positions so edges can be built from them.
  tree.matrix[!is.na(Yes), Yes := paste0(abs.node.position, "_0")]
  tree.matrix[!is.na(No), No := paste0(abs.node.position, "_1")]

  # Drop the "<tree>-" prefix so identical paths from different trees coincide.
  remove.tree <- function(x) str_replace(x, pattern = "^\\d+-", replacement = "")
  tree.matrix[, `:=`(abs.node.position = remove.tree(abs.node.position),
                     Yes = remove.tree(Yes),
                     No = remove.tree(No))]

  # Sum the Quality of each feature at each merged position, then label every
  # position with its features.keep highest-ranked features.
  nodes.dt <- tree.matrix[, .(Quality = sum(Quality)), by = .(abs.node.position, Feature)][
    , {
      top <- order(-Quality)[seq_len(min(length(Quality), features.keep))]
      list(Text = paste0(Feature[top], " (", Quality[top], ")", collapse = "\n"))
    },
    by = abs.node.position
  ]

  # One edge per distinct (parent position, child position) pair.
  split.nodes <- tree.matrix[Feature != "Leaf"]
  edges.dt <- rbindlist(list(
    split.nodes[, .(From = abs.node.position, To = Yes)],
    split.nodes[, .(From = abs.node.position, To = No)]
  ))
  edges.dt <- unique(edges.dt, by = c("From", "To"))

  nodes <- DiagrammeR::create_nodes(nodes = nodes.dt[, abs.node.position],
                                    label = nodes.dt[, Text],
                                    style = "filled",
                                    color = "DimGray",
                                    fillcolor = "Beige",
                                    shape = "oval",
                                    fontname = "Helvetica")

  edges <- DiagrammeR::create_edges(from = edges.dt[, From],
                                    to = edges.dt[, To],
                                    color = "DimGray",
                                    arrowsize = "1.5",
                                    arrowhead = "vee",
                                    fontname = "Helvetica",
                                    rel = "leading_to")

  graph <- DiagrammeR::create_graph(nodes_df = nodes,
                                    edges_df = edges,
                                    graph_attrs = "rankdir = LR")

  DiagrammeR::render_graph(graph, width = plot.width, height = plot.height)
}

globalVariables(
  c(
    "Feature", "no.nodes.abs.pos", "ID", "Yes", "No", "Tree", "yes.nodes.abs.pos",
    "abs.node.position", "Quality", "Text", "From", "To"
  )
)