Merge pull request #655 from pommedeterresautee/master

Add new multi tree plot function to R package
This commit is contained in:
Michaël Benesty 2015-11-28 18:08:27 +01:00
commit 09ed3f10cc
6 changed files with 177 additions and 8 deletions

View File

@ -12,6 +12,7 @@ export(xgb.load)
export(xgb.model.dt.tree)
export(xgb.plot.deepness)
export(xgb.plot.importance)
export(xgb.plot.multi.trees)
export(xgb.plot.tree)
export(xgb.save)
export(xgb.save.raw)
@ -36,6 +37,7 @@ importFrom(data.table,setnames)
importFrom(magrittr,"%>%")
importFrom(magrittr,add)
importFrom(magrittr,not)
importFrom(stringr,str_detect)
importFrom(stringr,str_extract)
importFrom(stringr,str_extract_all)
importFrom(stringr,str_match)

View File

@ -12,7 +12,6 @@
#' @importFrom magrittr add
#' @importFrom stringr str_extract
#' @importFrom stringr str_split
#' @importFrom stringr str_extract
#' @importFrom stringr str_trim
#' @param feature_names names of each feature as a character vector. Can be extracted from a sparse matrix (see example). If model dump already contains feature names, this argument should be \code{NULL}.
#' @param filename_dump the path to the text file storing the model. Model dump must include the gain per feature and per tree (parameter \code{with.stats = T} in function \code{xgb.dump}).

View File

@ -0,0 +1,112 @@
#' Project all trees on one tree and plot it
#'
#' Visualization of the ensemble of trees as a single collective unit.
#'
#' @importFrom data.table data.table
#' @importFrom data.table rbindlist
#' @importFrom data.table setnames
#' @importFrom data.table :=
#' @importFrom magrittr %>%
#' @importFrom stringr str_detect
#' @importFrom stringr str_extract
#'
#' @param filename_dump the path to the text file storing the model. Model dump must include the gain per feature and per tree (parameter \code{with.stats = T} in function \code{xgb.dump}).
#' @param model dump generated by the \code{xgb.train} function. Avoid the creation of a dump file.
#' @param features.keep number of features to keep in each position of the multi tree.
#' @param plot.width width in pixels of the graph to produce
#' @param plot.height height in pixels of the graph to produce
#'
#' @return Two graphs showing the distribution of the model deepness.
#'
#' @details
#'
#' This function tries to capture the complexity of gradient boosted tree ensembles
#' in a cohesive way.
#' The goal is to improve the interpretability of the model generally seen as black box.
#' The function is dedicated to boosting applied to decision trees only.
#'
#' The purpose is to move from an ensemble of trees to a single tree only.
#' It takes advantage of the fact that the shape of a binary tree is only defined by
#' its deepness.
#' Therefore in a boosting model, all trees have the same shape.
#' Moreover, the trees tend to reuse the same features.
#'
#' The function will project each trees on one, and keep for each position the
#' \code{features.keep} first features (based on Gain per feature).
#'
#' This function is inspired from this blog post:
#' \url{https://wellecks.wordpress.com/2015/02/21/peering-into-the-black-box-visualizing-lambdamart/}
#'
#' @examples
#' data(agaricus.train, package='xgboost')
#'
#' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max.depth = 15,
#' eta = 1, nthread = 2, nround = 30, objective = "binary:logistic",
#' min_child_weight = 50)
#'
#' p <- xgb.plot.multi.trees(bst, agaricus.train$data@Dimnames[[2]], 3)
#' print(p)
#'
#' @export
xgb.plot.multi.trees <- function(model, names, features.keep = 5, plot.width = NULL, plot.height = NULL){
tree.matrix <- xgb.model.dt.tree(names, model = model)
# first number of the path represents the tree, then the following numbers are related to the path to follow
# root init
root.nodes <- tree.matrix[str_detect(ID, "\\d+-0"), ID]
tree.matrix[ID %in% root.nodes, abs.node.position:=root.nodes]
precedent.nodes <- root.nodes
while(tree.matrix[,sum(is.na(abs.node.position))] > 0) {
yes.row.nodes <- tree.matrix[abs.node.position %in% precedent.nodes & !is.na(Yes)]
no.row.nodes <- tree.matrix[abs.node.position %in% precedent.nodes & !is.na(No)]
yes.nodes.abs.pos <- yes.row.nodes[, abs.node.position] %>% paste0("_0")
no.nodes.abs.pos <- no.row.nodes[, abs.node.position] %>% paste0("_1")
tree.matrix[ID %in% yes.row.nodes[, Yes], abs.node.position := yes.nodes.abs.pos]
tree.matrix[ID %in% no.row.nodes[, No], abs.node.position := no.nodes.abs.pos]
precedent.nodes <- c(yes.nodes.abs.pos, no.nodes.abs.pos)
}
tree.matrix[!is.na(Yes),Yes:= paste0(abs.node.position, "_0")]
tree.matrix[!is.na(No),No:= paste0(abs.node.position, "_1")]
remove.tree <- . %>% str_replace(pattern = "^\\d+-", replacement = "")
tree.matrix[,`:=`(abs.node.position=remove.tree(abs.node.position), Yes=remove.tree(Yes), No=remove.tree(No))]
nodes.dt <- tree.matrix[,.(Quality = sum(Quality)),by = .(abs.node.position, Feature)][,.(Text =paste0(Feature[1:min(length(Feature), features.keep)], " (", Quality[1:min(length(Quality), features.keep)], ")") %>% paste0(collapse = "\n")), by=abs.node.position]
edges.dt <- tree.matrix[Feature != "Leaf",.(abs.node.position, Yes)] %>% list(tree.matrix[Feature != "Leaf",.(abs.node.position, No)]) %>% rbindlist() %>% setnames(c("From", "To")) %>% .[,.N,.(From, To)] %>% .[,N:=NULL]
nodes <- DiagrammeR::create_nodes(nodes = nodes.dt[,abs.node.position],
label = nodes.dt[,Text],
style = "filled",
color = "DimGray",
fillcolor= "Beige",
shape = "oval",
fontname = "Helvetica"
)
edges <- DiagrammeR::create_edges(from = edges.dt[,From],
to = edges.dt[,To],
color = "DimGray",
arrowsize = "1.5",
arrowhead = "vee",
fontname = "Helvetica",
rel = "leading_to")
graph <- DiagrammeR::create_graph(nodes_df = nodes,
edges_df = edges,
graph_attrs = "rankdir = LR")
DiagrammeR::render_graph(graph, width = plot.width, height = plot.height)
}
globalVariables(
c(
"Feature", "no.nodes.abs.pos", "ID", "Yes", "No", "Tree", "yes.nodes.abs.pos", "abs.node.position"
)
)

View File

@ -10,8 +10,8 @@
#' @param filename_dump the path to the text file storing the model. Model dump must include the gain per feature and per tree (parameter \code{with.stats = T} in function \code{xgb.dump}). Possible to provide a model directly (see \code{model} argument).
#' @param model generated by the \code{xgb.train} function. Avoid the creation of a dump file.
#' @param n_first_tree limit the plot to the n first trees. If \code{NULL}, all trees of the model are plotted. Performance can be low for huge models.
#' @param width the width of the diagram in pixels.
#' @param height the height of the diagram in pixels.
#' @param plot.width the width of the diagram in pixels.
#' @param plot.height the height of the diagram in pixels.
#'
#' @return A \code{DiagrammeR} of the model.
#'
@ -43,7 +43,7 @@
#' xgb.plot.tree(agaricus.train$data@@Dimnames[[2]], model = bst)
#'
#' @export
xgb.plot.tree <- function(feature_names = NULL, filename_dump = NULL, model = NULL, n_first_tree = NULL, width = NULL, height = NULL){
xgb.plot.tree <- function(feature_names = NULL, filename_dump = NULL, model = NULL, n_first_tree = NULL, plot.width = NULL, plot.height = NULL){
if (!class(model) %in% c("xgb.Booster", "NULL")) {
stop("model: Has to be an object of class xgb.Booster model generaged by the xgb.train function.")
@ -87,7 +87,7 @@ xgb.plot.tree <- function(feature_names = NULL, filename_dump = NULL, model = NU
edges_df = edges,
graph_attrs = "rankdir = LR")
DiagrammeR::render_graph(graph, width = width, height = height)
DiagrammeR::render_graph(graph, width = plot.width, height = plot.height)
}
# Avoid error messages during CRAN check.

View File

@ -0,0 +1,56 @@
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/xgb.plot.multi.trees.R
\name{xgb.plot.multi.trees}
\alias{xgb.plot.multi.trees}
\title{Project all trees on one tree and plot it}
\usage{
xgb.plot.multi.trees(model, names, features.keep = 5, plot.width = NULL,
plot.height = NULL)
}
\arguments{
\item{model}{dump generated by the \code{xgb.train} function. Avoid the creation of a dump file.}
\item{features.keep}{number of features to keep in each position of the multi tree.}
\item{plot.width}{width in pixels of the graph to produce}
\item{plot.height}{height in pixels of the graph to produce}
\item{filename_dump}{the path to the text file storing the model. Model dump must include the gain per feature and per tree (parameter \code{with.stats = T} in function \code{xgb.dump}).}
}
\value{
Two graphs showing the distribution of the model deepness.
}
\description{
Visualization of the ensemble of trees as a single collective unit.
}
\details{
This function tries to capture the complexity of gradient boosted tree ensembles
in a cohesive way.
The goal is to improve the interpretability of the model generally seen as black box.
The function is dedicated to boosting applied to decision trees only.
The purpose is to move from an ensemble of trees to a single tree only.
It takes advantage of the fact that the shape of a binary tree is only defined by
its deepness.
Therefore in a boosting model, all trees have the same shape.
Moreover, the trees tend to reuse the same features.
The function will project each trees on one, and keep for each position the
\code{features.keep} first features (based on Gain per feature).
This function is inspired from this blog post:
\url{https://wellecks.wordpress.com/2015/02/21/peering-into-the-black-box-visualizing-lambdamart/}
}
\examples{
data(agaricus.train, package='xgboost')
bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max.depth = 15,
eta = 1, nthread = 2, nround = 30, objective = "binary:logistic",
min_child_weight = 50)
p <- xgb.plot.multi.trees(bst, agaricus.train$data@Dimnames[[2]], 3)
print(p)
}

View File

@ -5,7 +5,7 @@
\title{Plot a boosted tree model}
\usage{
xgb.plot.tree(feature_names = NULL, filename_dump = NULL, model = NULL,
n_first_tree = NULL, width = NULL, height = NULL)
n_first_tree = NULL, plot.width = NULL, plot.height = NULL)
}
\arguments{
\item{feature_names}{names of each feature as a character vector. Can be extracted from a sparse matrix (see example). If model dump already contains feature names, this argument should be \code{NULL}.}
@ -16,9 +16,9 @@ xgb.plot.tree(feature_names = NULL, filename_dump = NULL, model = NULL,
\item{n_first_tree}{limit the plot to the n first trees. If \code{NULL}, all trees of the model are plotted. Performance can be low for huge models.}
\item{width}{the width of the diagram in pixels.}
\item{plot.width}{the width of the diagram in pixels.}
\item{height}{the height of the diagram in pixels.}
\item{plot.height}{the height of the diagram in pixels.}
}
\value{
A \code{DiagrammeR} of the model.