diff --git a/R-package/NAMESPACE b/R-package/NAMESPACE index 7f6fa5817..a9ae672a3 100644 --- a/R-package/NAMESPACE +++ b/R-package/NAMESPACE @@ -12,6 +12,7 @@ export(xgb.load) export(xgb.model.dt.tree) export(xgb.plot.deepness) export(xgb.plot.importance) +export(xgb.plot.multi.trees) export(xgb.plot.tree) export(xgb.save) export(xgb.save.raw) @@ -36,6 +37,7 @@ importFrom(data.table,setnames) importFrom(magrittr,"%>%") importFrom(magrittr,add) importFrom(magrittr,not) +importFrom(stringr,str_detect) importFrom(stringr,str_extract) importFrom(stringr,str_extract_all) importFrom(stringr,str_match) diff --git a/R-package/R/xgb.model.dt.tree.R b/R-package/R/xgb.model.dt.tree.R index 5833389e2..13d3ecc5b 100644 --- a/R-package/R/xgb.model.dt.tree.R +++ b/R-package/R/xgb.model.dt.tree.R @@ -12,7 +12,6 @@ #' @importFrom magrittr add #' @importFrom stringr str_extract #' @importFrom stringr str_split -#' @importFrom stringr str_extract #' @importFrom stringr str_trim #' @param feature_names names of each feature as a character vector. Can be extracted from a sparse matrix (see example). If model dump already contains feature names, this argument should be \code{NULL}. #' @param filename_dump the path to the text file storing the model. Model dump must include the gain per feature and per tree (parameter \code{with.stats = T} in function \code{xgb.dump}). diff --git a/R-package/R/xgb.plot.multi.trees.R b/R-package/R/xgb.plot.multi.trees.R new file mode 100644 index 000000000..f53d1a13f --- /dev/null +++ b/R-package/R/xgb.plot.multi.trees.R @@ -0,0 +1,112 @@ +#' Project all trees on one tree and plot it +#' +#' Visualization of the ensemble of trees as a single collective unit. +#' +#' @importFrom data.table data.table +#' @importFrom data.table rbindlist +#' @importFrom data.table setnames +#' @importFrom data.table := +#' @importFrom magrittr %>% +#' @importFrom stringr str_detect +#' @importFrom stringr str_extract +#' +#' @param filename_dump the path to the text file storing the model. Model dump must include the gain per feature and per tree (parameter \code{with.stats = T} in function \code{xgb.dump}). +#' @param model dump generated by the \code{xgb.train} function. Avoid the creation of a dump file. +#' @param features.keep number of features to keep in each position of the multi tree. +#' @param plot.width width in pixels of the graph to produce +#' @param plot.height height in pixels of the graph to produce +#' +#' @return Two graphs showing the distribution of the model deepness. +#' +#' @details +#' +#' This function tries to capture the complexity of gradient boosted tree ensembles +#' in a cohesive way. +#' The goal is to improve the interpretability of the model generally seen as black box. +#' The function is dedicated to boosting applied to decision trees only. +#' +#' The purpose is to move from an ensemble of trees to a single tree only. +#' It takes advantage of the fact that the shape of a binary tree is only defined by +#' its deepness. +#' Therefore in a boosting model, all trees have the same shape. +#' Moreover, the trees tend to reuse the same features. +#' +#' The function will project each trees on one, and keep for each position the +#' \code{features.keep} first features (based on Gain per feature). +#' +#' This function is inspired from this blog post: +#' \url{https://wellecks.wordpress.com/2015/02/21/peering-into-the-black-box-visualizing-lambdamart/} +#' +#' @examples +#' data(agaricus.train, package='xgboost') +#' +#' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max.depth = 15, +#' eta = 1, nthread = 2, nround = 30, objective = "binary:logistic", +#' min_child_weight = 50) +#' +#' p <- xgb.plot.multi.trees(bst, agaricus.train$data@Dimnames[[2]], 3) +#' print(p) +#' +#' @export +xgb.plot.multi.trees <- function(model, names, features.keep = 5, plot.width = NULL, plot.height = NULL){ + tree.matrix <- xgb.model.dt.tree(names, model = model) + + # first number of the path represents the tree, then the following numbers are related to the path to follow + # root init + root.nodes <- tree.matrix[str_detect(ID, "\\d+-0"), ID] + tree.matrix[ID %in% root.nodes, abs.node.position:=root.nodes] + + precedent.nodes <- root.nodes + + while(tree.matrix[,sum(is.na(abs.node.position))] > 0) { + yes.row.nodes <- tree.matrix[abs.node.position %in% precedent.nodes & !is.na(Yes)] + no.row.nodes <- tree.matrix[abs.node.position %in% precedent.nodes & !is.na(No)] + yes.nodes.abs.pos <- yes.row.nodes[, abs.node.position] %>% paste0("_0") + no.nodes.abs.pos <- no.row.nodes[, abs.node.position] %>% paste0("_1") + + tree.matrix[ID %in% yes.row.nodes[, Yes], abs.node.position := yes.nodes.abs.pos] + tree.matrix[ID %in% no.row.nodes[, No], abs.node.position := no.nodes.abs.pos] + precedent.nodes <- c(yes.nodes.abs.pos, no.nodes.abs.pos) + } + + tree.matrix[!is.na(Yes),Yes:= paste0(abs.node.position, "_0")] + tree.matrix[!is.na(No),No:= paste0(abs.node.position, "_1")] + + + + remove.tree <- . %>% str_replace(pattern = "^\\d+-", replacement = "") + + tree.matrix[,`:=`(abs.node.position=remove.tree(abs.node.position), Yes=remove.tree(Yes), No=remove.tree(No))] + + nodes.dt <- tree.matrix[,.(Quality = sum(Quality)),by = .(abs.node.position, Feature)][,.(Text =paste0(Feature[1:min(length(Feature), features.keep)], " (", Quality[1:min(length(Quality), features.keep)], ")") %>% paste0(collapse = "\n")), by=abs.node.position] + edges.dt <- tree.matrix[Feature != "Leaf",.(abs.node.position, Yes)] %>% list(tree.matrix[Feature != "Leaf",.(abs.node.position, No)]) %>% rbindlist() %>% setnames(c("From", "To")) %>% .[,.N,.(From, To)] %>% .[,N:=NULL] + + nodes <- DiagrammeR::create_nodes(nodes = nodes.dt[,abs.node.position], + label = nodes.dt[,Text], + style = "filled", + color = "DimGray", + fillcolor= "Beige", + shape = "oval", + fontname = "Helvetica" + ) + + edges <- DiagrammeR::create_edges(from = edges.dt[,From], + to = edges.dt[,To], + color = "DimGray", + arrowsize = "1.5", + arrowhead = "vee", + fontname = "Helvetica", + rel = "leading_to") + + graph <- DiagrammeR::create_graph(nodes_df = nodes, + edges_df = edges, + graph_attrs = "rankdir = LR") + + DiagrammeR::render_graph(graph, width = plot.width, height = plot.height) +} + +globalVariables( + c( + "Feature", "no.nodes.abs.pos", "ID", "Yes", "No", "Tree", "yes.nodes.abs.pos", "abs.node.position" + ) +) \ No newline at end of file diff --git a/R-package/R/xgb.plot.tree.R b/R-package/R/xgb.plot.tree.R index 63bebf6cf..2976f1b07 100644 --- a/R-package/R/xgb.plot.tree.R +++ b/R-package/R/xgb.plot.tree.R @@ -10,8 +10,8 @@ #' @param filename_dump the path to the text file storing the model. Model dump must include the gain per feature and per tree (parameter \code{with.stats = T} in function \code{xgb.dump}). Possible to provide a model directly (see \code{model} argument). #' @param model generated by the \code{xgb.train} function. Avoid the creation of a dump file. #' @param n_first_tree limit the plot to the n first trees. If \code{NULL}, all trees of the model are plotted. Performance can be low for huge models. -#' @param width the width of the diagram in pixels. -#' @param height the height of the diagram in pixels. +#' @param plot.width the width of the diagram in pixels. +#' @param plot.height the height of the diagram in pixels. #' #' @return A \code{DiagrammeR} of the model. #' @@ -43,7 +43,7 @@ #' xgb.plot.tree(agaricus.train$data@@Dimnames[[2]], model = bst) #' #' @export -xgb.plot.tree <- function(feature_names = NULL, filename_dump = NULL, model = NULL, n_first_tree = NULL, width = NULL, height = NULL){ +xgb.plot.tree <- function(feature_names = NULL, filename_dump = NULL, model = NULL, n_first_tree = NULL, plot.width = NULL, plot.height = NULL){ if (!class(model) %in% c("xgb.Booster", "NULL")) { stop("model: Has to be an object of class xgb.Booster model generaged by the xgb.train function.") @@ -87,7 +87,7 @@ xgb.plot.tree <- function(feature_names = NULL, filename_dump = NULL, model = NU edges_df = edges, graph_attrs = "rankdir = LR") - DiagrammeR::render_graph(graph, width = width, height = height) + DiagrammeR::render_graph(graph, width = plot.width, height = plot.height) } # Avoid error messages during CRAN check. diff --git a/R-package/man/xgb.plot.multi.trees.Rd b/R-package/man/xgb.plot.multi.trees.Rd new file mode 100644 index 000000000..2bbe29ca5 --- /dev/null +++ b/R-package/man/xgb.plot.multi.trees.Rd @@ -0,0 +1,56 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/xgb.plot.multi.trees.R +\name{xgb.plot.multi.trees} +\alias{xgb.plot.multi.trees} +\title{Project all trees on one tree and plot it} +\usage{ +xgb.plot.multi.trees(model, names, features.keep = 5, plot.width = NULL, + plot.height = NULL) +} +\arguments{ +\item{model}{dump generated by the \code{xgb.train} function. Avoid the creation of a dump file.} + +\item{features.keep}{number of features to keep in each position of the multi tree.} + +\item{plot.width}{width in pixels of the graph to produce} + +\item{plot.height}{height in pixels of the graph to produce} + +\item{filename_dump}{the path to the text file storing the model. Model dump must include the gain per feature and per tree (parameter \code{with.stats = T} in function \code{xgb.dump}).} +} +\value{ +Two graphs showing the distribution of the model deepness. +} +\description{ +Visualization of the ensemble of trees as a single collective unit. +} +\details{ +This function tries to capture the complexity of gradient boosted tree ensembles +in a cohesive way. +The goal is to improve the interpretability of the model generally seen as black box. +The function is dedicated to boosting applied to decision trees only. + +The purpose is to move from an ensemble of trees to a single tree only. +It takes advantage of the fact that the shape of a binary tree is only defined by +its deepness. +Therefore in a boosting model, all trees have the same shape. +Moreover, the trees tend to reuse the same features. + +The function will project each trees on one, and keep for each position the +\code{features.keep} first features (based on Gain per feature). + +This function is inspired from this blog post: +\url{https://wellecks.wordpress.com/2015/02/21/peering-into-the-black-box-visualizing-lambdamart/} +} +\examples{ +data(agaricus.train, package='xgboost') + +bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max.depth = 15, + eta = 1, nthread = 2, nround = 30, objective = "binary:logistic", + min_child_weight = 50) + +p <- xgb.plot.multi.trees(bst, agaricus.train$data@Dimnames[[2]], 3) +print(p) + +} + diff --git a/R-package/man/xgb.plot.tree.Rd b/R-package/man/xgb.plot.tree.Rd index f34e75bf9..2008014cf 100644 --- a/R-package/man/xgb.plot.tree.Rd +++ b/R-package/man/xgb.plot.tree.Rd @@ -5,7 +5,7 @@ \title{Plot a boosted tree model} \usage{ xgb.plot.tree(feature_names = NULL, filename_dump = NULL, model = NULL, - n_first_tree = NULL, width = NULL, height = NULL) + n_first_tree = NULL, plot.width = NULL, plot.height = NULL) } \arguments{ \item{feature_names}{names of each feature as a character vector. Can be extracted from a sparse matrix (see example). If model dump already contains feature names, this argument should be \code{NULL}.} @@ -16,9 +16,9 @@ xgb.plot.tree(feature_names = NULL, filename_dump = NULL, model = NULL, \item{n_first_tree}{limit the plot to the n first trees. If \code{NULL}, all trees of the model are plotted. Performance can be low for huge models.} -\item{width}{the width of the diagram in pixels.} +\item{plot.width}{the width of the diagram in pixels.} -\item{height}{the height of the diagram in pixels.} +\item{plot.height}{the height of the diagram in pixels.} } \value{ A \code{DiagrammeR} of the model.