From 5169d087353f23d90d5b3cf4439179ed6d03ff3e Mon Sep 17 00:00:00 2001 From: pommedeterresautee Date: Fri, 27 Nov 2015 14:49:06 +0100 Subject: [PATCH] Add new multi.tree function to R package --- R-package/R/xgb.plot.multi.trees.R | 100 +++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 R-package/R/xgb.plot.multi.trees.R diff --git a/R-package/R/xgb.plot.multi.trees.R b/R-package/R/xgb.plot.multi.trees.R new file mode 100644 index 000000000..037b66e70 --- /dev/null +++ b/R-package/R/xgb.plot.multi.trees.R @@ -0,0 +1,100 @@ +library(stringr) +library(data.table) +library(xgboost) + +#' Project all trees on one and plot it +#' +#' Provide a way to display on one tree all trees of the model. +#' +#' @importFrom data.table data.table +#' @importFrom data.table rbindlist +#' @importFrom data.table setnames +#' @importFrom data.table := +#' @importFrom magrittr %>% +#' @param filename_dump the path to the text file storing the model. Model dump must include the gain per feature and per tree (parameter \code{with.stats = T} in function \code{xgb.dump}). +#' @param model dump generated by the \code{xgb.train} function. Avoid the creation of a dump file. +#' +#' @return Two graphs showing the distribution of the model deepness. +#' +#' @details +#' +#' This function tries to capture the complexity of gradient boosted tree ensembles in a cohesive way. +#' The goal is to improve the interpretability of the model generally seen as black box. +#' The function is dedicated to boosting applied to trees only. It won't work on GLM. +#' +#' The purpose is to move from an ensemble of trees to a single tree only. +#' It leverages the fact that the shape of a binary tree is only defined by its deepness. +#' The second fact which is leverage is that all trees in a boosting model tend to share the features they use. +#' +#' The function will project each trees on one tree, and keep the \code{keepN} first feature for each position. +#' This function is inspired from this blog post: +#' \url{https://wellecks.wordpress.com/2015/02/21/peering-into-the-black-box-visualizing-lambdamart/} +#' +#' @examples +#' data(agaricus.train, package='xgboost') +#' +#' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max.depth = 15, +#' eta = 1, nthread = 2, nround = 30, objective = "binary:logistic", +#' min_child_weight = 50) +#' +#' p <- xgb.plot.multi.trees(bst, agaricus.train$data@Dimnames[[2]], 3) +#' print(p) +#' +#' @export +xgb.plot.multi.trees <- function(model, names, keepN = 5, plot.width = NULL, plot.height = NULL){ + tree.matrix <- xgb.model.dt.tree(names, model = model) + + # first number of the path represents the tree, then the following numbers are related to the path to follow + # root init + root.nodes <- tree.matrix[str_detect(ID, "\\d+-0"), ID] + tree.matrix[ID %in% root.nodes, abs.node.position:=root.nodes] + + precedent.nodes <- root.nodes + + while(tree.matrix[,sum(is.na(abs.node.position))] > 0) { + yes.row.nodes <- tree.matrix[abs.node.position %in% precedent.nodes & !is.na(Yes)] + no.row.nodes <- tree.matrix[abs.node.position %in% precedent.nodes & !is.na(No)] + yes.nodes.abs.pos <- yes.row.nodes[, abs.node.position] %>% paste0("_0") + no.nodes.abs.pos <- no.row.nodes[, abs.node.position] %>% paste0("_1") + + tree.matrix[ID %in% yes.row.nodes[, Yes], abs.node.position := yes.nodes.abs.pos] + tree.matrix[ID %in% no.row.nodes[, No], abs.node.position := no.nodes.abs.pos] + precedent.nodes <- c(yes.nodes.abs.pos, no.nodes.abs.pos) + } + + tree.matrix[!is.na(Yes),Yes:= paste0(abs.node.position, "_0")] + tree.matrix[!is.na(No),No:= paste0(abs.node.position, "_1")] + + + + remove.tree <- . %>% str_replace(pattern = "^\\d+-", replacement = "") + + tree.matrix[,`:=`(abs.node.position=remove.tree(abs.node.position), Yes=remove.tree(Yes), No=remove.tree(No))] + + nodes.dt <- tree.matrix[,.(Quality = sum(Quality)),by = .(abs.node.position, Feature)][,.(Text =paste0(Feature[1:min(length(Feature), keepN)], " (", Quality[1:min(length(Quality), keepN)], ")") %>% paste0(collapse = "\n")), by=abs.node.position] + edges.dt <- tree.matrix[Feature != "Leaf",.(abs.node.position, Yes)] %>% list(tree.matrix[Feature != "Leaf",.(abs.node.position, No)]) %>% rbindlist() %>% setnames(c("From", "To")) %>% .[,.N,.(From, To)] %>% .[,N:=NULL] + + nodes <- DiagrammeR::create_nodes(nodes = nodes.dt[,abs.node.position], + label = nodes.dt[,Text], + style = "filled", + color = "DimGray", + fillcolor= "Blue", + shape = "oval", + #data = allTrees[,Feature] + fontname = "Helvetica" + ) + + edges <- DiagrammeR::create_edges(from = edges.dt[,From], + to = edges.dt[,To], + color = "DimGray", + arrowsize = "1.5", + arrowhead = "vee", + fontname = "Helvetica", + rel = "leading_to") + + graph <- DiagrammeR::create_graph(nodes_df = nodes, + edges_df = edges, + graph_attrs = "rankdir = LR") + + DiagrammeR::render_graph(graph, width = plot.width, height = plot.height) +}