Merge pull request #655 from pommedeterresautee/master
Add new multi tree plot function to R package
This commit is contained in:
commit
bf19d821e0
@ -12,6 +12,7 @@ export(xgb.load)
|
||||
export(xgb.model.dt.tree)
|
||||
export(xgb.plot.deepness)
|
||||
export(xgb.plot.importance)
|
||||
export(xgb.plot.multi.trees)
|
||||
export(xgb.plot.tree)
|
||||
export(xgb.save)
|
||||
export(xgb.save.raw)
|
||||
@ -36,6 +37,7 @@ importFrom(data.table,setnames)
|
||||
importFrom(magrittr,"%>%")
|
||||
importFrom(magrittr,add)
|
||||
importFrom(magrittr,not)
|
||||
importFrom(stringr,str_detect)
|
||||
importFrom(stringr,str_extract)
|
||||
importFrom(stringr,str_extract_all)
|
||||
importFrom(stringr,str_match)
|
||||
|
||||
@ -12,7 +12,6 @@
|
||||
#' @importFrom magrittr add
|
||||
#' @importFrom stringr str_extract
|
||||
#' @importFrom stringr str_split
|
||||
#' @importFrom stringr str_extract
|
||||
#' @importFrom stringr str_trim
|
||||
#' @param feature_names names of each feature as a character vector. Can be extracted from a sparse matrix (see example). If model dump already contains feature names, this argument should be \code{NULL}.
|
||||
#' @param filename_dump the path to the text file storing the model. Model dump must include the gain per feature and per tree (parameter \code{with.stats = T} in function \code{xgb.dump}).
|
||||
|
||||
112
R-package/R/xgb.plot.multi.trees.R
Normal file
112
R-package/R/xgb.plot.multi.trees.R
Normal file
@ -0,0 +1,112 @@
|
||||
#' Project all trees on one tree and plot it
#'
#' Visualization of the ensemble of trees as a single collective unit.
#'
#' @importFrom data.table data.table
#' @importFrom data.table rbindlist
#' @importFrom data.table setnames
#' @importFrom data.table :=
#' @importFrom magrittr %>%
#' @importFrom stringr str_detect
#' @importFrom stringr str_extract
#'
#' @param model model generated by the \code{xgb.train} function. Avoid the creation of a dump file.
#' @param names names of each feature as a character vector. It is passed as the feature names to \code{xgb.model.dt.tree}.
#' @param features.keep number of features to keep in each position of the multi tree.
#' @param plot.width width in pixels of the graph to produce
#' @param plot.height height in pixels of the graph to produce
#'
#' @return The rendered \code{DiagrammeR} graph of the ensemble projected on a single tree.
#'
#' @details
#'
#' This function tries to capture the complexity of gradient boosted tree ensembles
#' in a cohesive way.
#' The goal is to improve the interpretability of the model generally seen as black box.
#' The function is dedicated to boosting applied to decision trees only.
#'
#' The purpose is to move from an ensemble of trees to a single tree only.
#' It takes advantage of the fact that the shape of a binary tree is only defined by
#' its deepness.
#' Therefore in a boosting model, all trees have the same shape.
#' Moreover, the trees tend to reuse the same features.
#'
#' The function will project each tree on one, and keep for each position the
#' \code{features.keep} first features (based on Gain per feature, summed over
#' all trees and sorted in decreasing order).
#'
#' This function is inspired from this blog post:
#' \url{https://wellecks.wordpress.com/2015/02/21/peering-into-the-black-box-visualizing-lambdamart/}
#'
#' @examples
#' data(agaricus.train, package='xgboost')
#'
#' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max.depth = 15,
#'                  eta = 1, nthread = 2, nround = 30, objective = "binary:logistic",
#'                  min_child_weight = 50)
#'
#' p <- xgb.plot.multi.trees(bst, agaricus.train$data@Dimnames[[2]], 3)
#' print(p)
#'
#' @export
xgb.plot.multi.trees <- function(model, names, features.keep = 5, plot.width = NULL, plot.height = NULL){
  tree.matrix <- xgb.model.dt.tree(names, model = model)

  # Every node gets an "absolute position": the ID of the root of its tree
  # followed by the path taken from that root ("_0" for a Yes branch, "_1" for
  # a No branch). Once the tree prefix is stripped, nodes sitting at the same
  # place in different trees share the same position.

  # Root init: roots are the nodes whose ID is "<tree>-0". The pattern is
  # anchored so that only the exact root ID can match.
  root.nodes <- tree.matrix[str_detect(ID, "^\\d+-0$"), ID]
  if (length(root.nodes) == 0) {
    stop("model: no root node found in the tree table.")
  }
  tree.matrix[ID %in% root.nodes, abs.node.position := ID]

  precedent.nodes <- root.nodes

  # Walk down the trees one level per iteration until every node is positioned.
  while (tree.matrix[, sum(is.na(abs.node.position))] > 0) {
    yes.row.nodes <- tree.matrix[abs.node.position %in% precedent.nodes & !is.na(Yes)]
    no.row.nodes <- tree.matrix[abs.node.position %in% precedent.nodes & !is.na(No)]

    # Without this guard the loop would never end when a node cannot be
    # reached from any root: paste0() on an empty vector still returns "_0".
    if (nrow(yes.row.nodes) == 0 && nrow(no.row.nodes) == 0) {
      stop("model: some nodes cannot be reached from the root of their tree.")
    }

    # Child IDs and the position each of them inherits from its parent. The
    # two vectors are kept aligned so that the assignment below does not
    # depend on the row order of the table.
    child.id <- character(0)
    child.pos <- character(0)
    if (nrow(yes.row.nodes) > 0) {
      child.id <- c(child.id, yes.row.nodes[, Yes])
      child.pos <- c(child.pos, paste0(yes.row.nodes[, abs.node.position], "_0"))
    }
    if (nrow(no.row.nodes) > 0) {
      child.id <- c(child.id, no.row.nodes[, No])
      child.pos <- c(child.pos, paste0(no.row.nodes[, abs.node.position], "_1"))
    }

    tree.matrix[ID %in% child.id, abs.node.position := child.pos[match(ID, child.id)]]
    precedent.nodes <- child.pos
  }

  # Express the children as absolute positions as well.
  tree.matrix[!is.na(Yes), Yes := paste0(abs.node.position, "_0")]
  tree.matrix[!is.na(No), No := paste0(abs.node.position, "_1")]

  # Drop the leading "<tree>-" so that all trees are projected on one.
  # Base sub() keeps NA values (leaves have no Yes / No) as NA.
  remove.tree <- function(x) sub("^\\d+-", "", x)

  tree.matrix[, `:=`(abs.node.position = remove.tree(abs.node.position),
                     Yes = remove.tree(Yes),
                     No = remove.tree(No))]

  # Sum the quality of each feature per position, sort by decreasing quality
  # and keep the features.keep best features of each position as node label.
  nodes.dt <- tree.matrix[, .(Quality = sum(Quality)), by = .(abs.node.position, Feature)
    ][order(-Quality)
    ][, {
      kept <- seq_len(min(.N, features.keep))
      .(Text = paste0(Feature[kept], " (", Quality[kept], ")") %>% paste0(collapse = "\n"))
    }, by = abs.node.position]

  # One edge per distinct (parent position, child position) pair. Columns are
  # named before binding so that rbindlist never has to reconcile Yes and No.
  split.nodes <- tree.matrix[Feature != "Leaf"]
  edges.dt <- list(split.nodes[, .(From = abs.node.position, To = Yes)],
                   split.nodes[, .(From = abs.node.position, To = No)]) %>%
    rbindlist() %>%
    unique()

  nodes <- DiagrammeR::create_nodes(nodes = nodes.dt[,abs.node.position],
                                    label = nodes.dt[,Text],
                                    style = "filled",
                                    color = "DimGray",
                                    fillcolor= "Beige",
                                    shape = "oval",
                                    fontname = "Helvetica"
  )

  edges <- DiagrammeR::create_edges(from = edges.dt[,From],
                                    to = edges.dt[,To],
                                    color = "DimGray",
                                    arrowsize = "1.5",
                                    arrowhead = "vee",
                                    fontname = "Helvetica",
                                    rel = "leading_to")

  graph <- DiagrammeR::create_graph(nodes_df = nodes,
                                    edges_df = edges,
                                    graph_attrs = "rankdir = LR")

  DiagrammeR::render_graph(graph, width = plot.width, height = plot.height)
}
|
||||
|
||||
# Names referenced through data.table non-standard evaluation in this file.
# Registering them is presumably meant to keep the package check from
# reporting them as undefined global variables.
globalVariables(c(
  "Feature", "no.nodes.abs.pos", "ID", "Yes",
  "No", "Tree", "yes.nodes.abs.pos", "abs.node.position"
))
|
||||
@ -10,8 +10,8 @@
|
||||
#' @param filename_dump the path to the text file storing the model. Model dump must include the gain per feature and per tree (parameter \code{with.stats = T} in function \code{xgb.dump}). Possible to provide a model directly (see \code{model} argument).
|
||||
#' @param model generated by the \code{xgb.train} function. Avoid the creation of a dump file.
|
||||
#' @param n_first_tree limit the plot to the n first trees. If \code{NULL}, all trees of the model are plotted. Performance can be low for huge models.
|
||||
#' @param width the width of the diagram in pixels.
|
||||
#' @param height the height of the diagram in pixels.
|
||||
#' @param plot.width the width of the diagram in pixels.
|
||||
#' @param plot.height the height of the diagram in pixels.
|
||||
#'
|
||||
#' @return A \code{DiagrammeR} of the model.
|
||||
#'
|
||||
@ -43,7 +43,7 @@
|
||||
#' xgb.plot.tree(agaricus.train$data@@Dimnames[[2]], model = bst)
|
||||
#'
|
||||
#' @export
|
||||
xgb.plot.tree <- function(feature_names = NULL, filename_dump = NULL, model = NULL, n_first_tree = NULL, width = NULL, height = NULL){
|
||||
xgb.plot.tree <- function(feature_names = NULL, filename_dump = NULL, model = NULL, n_first_tree = NULL, plot.width = NULL, plot.height = NULL){
|
||||
|
||||
if (!class(model) %in% c("xgb.Booster", "NULL")) {
|
||||
stop("model: Has to be an object of class xgb.Booster model generaged by the xgb.train function.")
|
||||
@ -87,7 +87,7 @@ xgb.plot.tree <- function(feature_names = NULL, filename_dump = NULL, model = NU
|
||||
edges_df = edges,
|
||||
graph_attrs = "rankdir = LR")
|
||||
|
||||
DiagrammeR::render_graph(graph, width = width, height = height)
|
||||
DiagrammeR::render_graph(graph, width = plot.width, height = plot.height)
|
||||
}
|
||||
|
||||
# Avoid error messages during CRAN check.
|
||||
|
||||
56
R-package/man/xgb.plot.multi.trees.Rd
Normal file
56
R-package/man/xgb.plot.multi.trees.Rd
Normal file
@ -0,0 +1,56 @@
|
||||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/xgb.plot.multi.trees.R
|
||||
\name{xgb.plot.multi.trees}
|
||||
\alias{xgb.plot.multi.trees}
|
||||
\title{Project all trees on one tree and plot it}
|
||||
\usage{
|
||||
xgb.plot.multi.trees(model, names, features.keep = 5, plot.width = NULL,
|
||||
plot.height = NULL)
|
||||
}
|
||||
\arguments{
|
||||
\item{model}{model generated by the \code{xgb.train} function. Avoid the creation of a dump file.}
|
||||
|
||||
\item{features.keep}{number of features to keep in each position of the multi tree.}
|
||||
|
||||
\item{plot.width}{width in pixels of the graph to produce}
|
||||
|
||||
\item{plot.height}{height in pixels of the graph to produce}
|
||||
|
||||
\item{names}{names of each feature as a character vector. It is passed as the feature names to \code{xgb.model.dt.tree}.}
|
||||
}
|
||||
\value{
|
||||
The rendered \code{DiagrammeR} graph of the ensemble projected on a single tree.
|
||||
}
|
||||
\description{
|
||||
Visualization of the ensemble of trees as a single collective unit.
|
||||
}
|
||||
\details{
|
||||
This function tries to capture the complexity of gradient boosted tree ensembles
|
||||
in a cohesive way.
|
||||
The goal is to improve the interpretability of the model generally seen as black box.
|
||||
The function is dedicated to boosting applied to decision trees only.
|
||||
|
||||
The purpose is to move from an ensemble of trees to a single tree only.
|
||||
It takes advantage of the fact that the shape of a binary tree is only defined by
|
||||
its deepness.
|
||||
Therefore in a boosting model, all trees have the same shape.
|
||||
Moreover, the trees tend to reuse the same features.
|
||||
|
||||
The function will project each tree on one, and keep for each position the
|
||||
\code{features.keep} first features (based on Gain per feature).
|
||||
|
||||
This function is inspired from this blog post:
|
||||
\url{https://wellecks.wordpress.com/2015/02/21/peering-into-the-black-box-visualizing-lambdamart/}
|
||||
}
|
||||
\examples{
|
||||
data(agaricus.train, package='xgboost')
|
||||
|
||||
bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max.depth = 15,
|
||||
eta = 1, nthread = 2, nround = 30, objective = "binary:logistic",
|
||||
min_child_weight = 50)
|
||||
|
||||
p <- xgb.plot.multi.trees(bst, agaricus.train$data@Dimnames[[2]], 3)
|
||||
print(p)
|
||||
|
||||
}
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
\title{Plot a boosted tree model}
|
||||
\usage{
|
||||
xgb.plot.tree(feature_names = NULL, filename_dump = NULL, model = NULL,
|
||||
n_first_tree = NULL, width = NULL, height = NULL)
|
||||
n_first_tree = NULL, plot.width = NULL, plot.height = NULL)
|
||||
}
|
||||
\arguments{
|
||||
\item{feature_names}{names of each feature as a character vector. Can be extracted from a sparse matrix (see example). If model dump already contains feature names, this argument should be \code{NULL}.}
|
||||
@ -16,9 +16,9 @@ xgb.plot.tree(feature_names = NULL, filename_dump = NULL, model = NULL,
|
||||
|
||||
\item{n_first_tree}{limit the plot to the n first trees. If \code{NULL}, all trees of the model are plotted. Performance can be low for huge models.}
|
||||
|
||||
\item{width}{the width of the diagram in pixels.}
|
||||
\item{plot.width}{the width of the diagram in pixels.}
|
||||
|
||||
\item{height}{the height of the diagram in pixels.}
|
||||
\item{plot.height}{the height of the diagram in pixels.}
|
||||
}
|
||||
\value{
|
||||
A \code{DiagrammeR} of the model.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user