[R] maintenance Nov 2017; SHAP plots (#2888)
* [R] fix predict contributions for data with no colnames * [R] add a render parameter for xgb.plot.multi.trees; fixes #2628 * [R] update Rd's * [R] remove unnecessary dep-package from R cmake install * silence type warnings; readability * [R] silence complaint about incomplete line at the end * [R] initial version of xgb.plot.shap() * [R] more work on xgb.plot.shap * [R] enforce black font in xgb.plot.tree; fixes #2640 * [R] if feature names are available, check in predict that they are the same; fixes #2857 * [R] cran check and lint fixes * remove tabs * [R] add references; a test for plot.shap
This commit is contained in:
committed by
Tong He
parent
1b77903eeb
commit
e8a6597957
@@ -1,5 +1,5 @@
|
||||
#' Project all trees on one tree and plot it
|
||||
#'
|
||||
#'
|
||||
#' Visualization of the ensemble of trees as a single collective unit.
|
||||
#'
|
||||
#' @param model produced by the \code{xgb.train} function.
|
||||
@@ -7,52 +7,71 @@
|
||||
#' @param features_keep number of features to keep in each position of the multi trees.
|
||||
#' @param plot_width width in pixels of the graph to produce
|
||||
#' @param plot_height height in pixels of the graph to produce
|
||||
#' @param render a logical flag for whether the graph should be rendered (see Value).
|
||||
#' @param ... currently not used
|
||||
#'
|
||||
#' @return Two graphs showing the distribution of the model deepness.
|
||||
#'
|
||||
#' @details
|
||||
#'
|
||||
#' This function tries to capture the complexity of a gradient boosted tree model
|
||||
#'
|
||||
#' This function tries to capture the complexity of a gradient boosted tree model
|
||||
#' in a cohesive way by compressing an ensemble of trees into a single tree-graph representation.
|
||||
#' The goal is to improve the interpretability of a model generally seen as black box.
|
||||
#'
|
||||
#'
|
||||
#' Note: this function is applicable to tree booster-based models only.
|
||||
#'
|
||||
#' It takes advantage of the fact that the shape of a binary tree is only defined by
|
||||
#' its depth (therefore, in a boosting model, all trees have similar shape).
|
||||
#'
|
||||
#'
|
||||
#' It takes advantage of the fact that the shape of a binary tree is only defined by
|
||||
#' its depth (therefore, in a boosting model, all trees have similar shape).
|
||||
#'
|
||||
#' Moreover, the trees tend to reuse the same features.
|
||||
#'
|
||||
#' The function projects each tree onto one, and keeps for each position the
|
||||
#'
|
||||
#' The function projects each tree onto one, and keeps for each position the
|
||||
#' \code{features_keep} first features (based on the Gain per feature measure).
|
||||
#'
|
||||
#'
|
||||
#' This function is inspired by this blog post:
|
||||
#' \url{https://wellecks.wordpress.com/2015/02/21/peering-into-the-black-box-visualizing-lambdamart/}
|
||||
#'
|
||||
#' @return
|
||||
#'
|
||||
#' When \code{render = TRUE}:
|
||||
#' returns a rendered graph object which is an \code{htmlwidget} of class \code{grViz}.
|
||||
#' Similar to ggplot objects, it needs to be printed to see it when not running from command line.
|
||||
#'
|
||||
#' When \code{render = FALSE}:
|
||||
#' silently returns a graph object which is of DiagrammeR's class \code{dgr_graph}.
|
||||
#' This could be useful if one wants to modify some of the graph attributes
|
||||
#' before rendering the graph with \code{\link[DiagrammeR]{render_graph}}.
|
||||
#'
|
||||
#' @examples
|
||||
#'
|
||||
#' data(agaricus.train, package='xgboost')
|
||||
#'
|
||||
#' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max_depth = 15,
|
||||
#' eta = 1, nthread = 2, nrounds = 30, objective = "binary:logistic",
|
||||
#' min_child_weight = 50)
|
||||
#' eta = 1, nthread = 2, nrounds = 30, objective = "binary:logistic",
|
||||
#' min_child_weight = 50, verbose = 0)
|
||||
#'
|
||||
#' p <- xgb.plot.multi.trees(model = bst, feature_names = colnames(agaricus.train$data),
|
||||
#' features_keep = 3)
|
||||
#' p <- xgb.plot.multi.trees(model = bst, features_keep = 3)
|
||||
#' print(p)
|
||||
#'
|
||||
#' \dontrun{
|
||||
#' # Below is an example of how to save this plot to a file.
|
||||
#' # Note that for `export_graph` to work, the DiagrammeRsvg and rsvg packages must also be installed.
|
||||
#' library(DiagrammeR)
|
||||
#' gr <- xgb.plot.multi.trees(model=bst, features_keep = 3, render=FALSE)
|
||||
#' export_graph(gr, 'tree.pdf', width=1500, height=600)
|
||||
#' }
|
||||
#'
|
||||
#' @export
|
||||
xgb.plot.multi.trees <- function(model, feature_names = NULL, features_keep = 5, plot_width = NULL, plot_height = NULL, ...){
|
||||
xgb.plot.multi.trees <- function(model, feature_names = NULL, features_keep = 5, plot_width = NULL, plot_height = NULL,
|
||||
render = TRUE, ...){
|
||||
check.deprecation(...)
|
||||
tree.matrix <- xgb.model.dt.tree(feature_names = feature_names, model = model)
|
||||
|
||||
|
||||
# first number of the path represents the tree, then the following numbers are related to the path to follow
|
||||
# root init
|
||||
root.nodes <- tree.matrix[stri_detect_regex(ID, "\\d+-0"), ID]
|
||||
tree.matrix[ID %in% root.nodes, abs.node.position:=root.nodes]
|
||||
|
||||
tree.matrix[ID %in% root.nodes, abs.node.position := root.nodes]
|
||||
|
||||
precedent.nodes <- root.nodes
|
||||
|
||||
|
||||
while(tree.matrix[,sum(is.na(abs.node.position))] > 0) {
|
||||
yes.row.nodes <- tree.matrix[abs.node.position %in% precedent.nodes & !is.na(Yes)]
|
||||
no.row.nodes <- tree.matrix[abs.node.position %in% precedent.nodes & !is.na(No)]
|
||||
@@ -64,9 +83,8 @@ xgb.plot.multi.trees <- function(model, feature_names = NULL, features_keep = 5,
|
||||
precedent.nodes <- c(yes.nodes.abs.pos, no.nodes.abs.pos)
|
||||
}
|
||||
|
||||
tree.matrix[!is.na(Yes),Yes:= paste0(abs.node.position, "_0")]
|
||||
tree.matrix[!is.na(No),No:= paste0(abs.node.position, "_1")]
|
||||
|
||||
tree.matrix[!is.na(Yes), Yes := paste0(abs.node.position, "_0")]
|
||||
tree.matrix[!is.na(No), No := paste0(abs.node.position, "_1")]
|
||||
|
||||
remove.tree <- . %>% stri_replace_first_regex(pattern = "^\\d+-", replacement = "")
|
||||
|
||||
@@ -120,8 +138,10 @@ xgb.plot.multi.trees <- function(model, feature_names = NULL, features_keep = 5,
|
||||
attr_type = "edge",
|
||||
attr = c("color", "arrowsize", "arrowhead", "fontname"),
|
||||
value = c("DimGray", "1.5", "vee", "Helvetica"))
|
||||
|
||||
DiagrammeR::render_graph(graph, width = plot_width, height = plot_height)
|
||||
|
||||
if (!render) return(invisible(graph))
|
||||
|
||||
DiagrammeR::render_graph(graph, width = plot_width, height = plot_height)
|
||||
}
|
||||
|
||||
globalVariables(c(".N", "N", "From", "To", "Text", "Feature", "no.nodes.abs.pos",
|
||||
|
||||
Reference in New Issue
Block a user