From e40c4260edddc0f96d640ca0e762d4e65fd57a8a Mon Sep 17 00:00:00 2001 From: david-cortes Date: Sat, 30 Dec 2023 06:28:27 +0100 Subject: [PATCH] [R] Enable 'dot' dump format (#9930) --- R-package/R/xgb.dump.R | 14 ++++++++-- R-package/R/xgb.plot.tree.R | 49 +++++++++++++++++++++++++++++++--- R-package/man/xgb.dump.Rd | 11 ++++++-- R-package/man/xgb.plot.tree.Rd | 35 ++++++++++++++++++++++-- 4 files changed, 100 insertions(+), 9 deletions(-) diff --git a/R-package/R/xgb.dump.R b/R-package/R/xgb.dump.R index a2de26c26..4421836d1 100644 --- a/R-package/R/xgb.dump.R +++ b/R-package/R/xgb.dump.R @@ -13,7 +13,10 @@ #' When this option is on, the model dump contains two additional values: #' gain is the approximate loss function gain we get in each split; #' cover is the sum of second order gradient in each node. -#' @param dump_format either 'text' or 'json' format could be specified. +#' @param dump_format either 'text', 'json', or 'dot' (graphviz) format could be specified. +#' +#' Format 'dot' for a single tree can be passed directly to packages that consume this format +#' for graph visualization, such as function [DiagrammeR::grViz()] #' @param ... currently not used #' #' @return @@ -37,9 +40,13 @@ #' # print in JSON format: #' cat(xgb.dump(bst, with_stats = TRUE, dump_format='json')) #' +#' # plot first tree leveraging the 'dot' format +#' if (requireNamespace('DiagrammeR', quietly = TRUE)) { +#' DiagrammeR::grViz(xgb.dump(bst, dump_format = "dot")[[1L]]) +#' } #' @export xgb.dump <- function(model, fname = NULL, fmap = "", with_stats = FALSE, - dump_format = c("text", "json"), ...) { + dump_format = c("text", "json", "dot"), ...) { check.deprecation(...) 
dump_format <- match.arg(dump_format) if (!inherits(model, "xgb.Booster")) @@ -52,6 +59,9 @@ xgb.dump <- function(model, fname = NULL, fmap = "", with_stats = FALSE, model <- xgb.Booster.complete(model) model_dump <- .Call(XGBoosterDumpModel_R, model$handle, NVL(fmap, "")[1], as.integer(with_stats), as.character(dump_format)) + if (dump_format == "dot") { + return(sapply(model_dump, function(x) gsub("^booster\\[\\d+\\]\\n", "\\1", x))) + } if (is.null(fname)) model_dump <- gsub('\t', '', model_dump, fixed = TRUE) diff --git a/R-package/R/xgb.plot.tree.R b/R-package/R/xgb.plot.tree.R index 29d00e111..8b12d8a68 100644 --- a/R-package/R/xgb.plot.tree.R +++ b/R-package/R/xgb.plot.tree.R @@ -14,11 +14,33 @@ #' The values are passed to [DiagrammeR::render_graph()]. #' @param render Should the graph be rendered or not? The default is `TRUE`. #' @param show_node_id a logical flag for whether to show node id's in the graph. +#' @param style Style to use for the plot. Options are:\itemize{ +#' \item `"xgboost"`: will use the plot style defined in the core XGBoost library, +#' which is shared between different interfaces through the 'dot' format. This +#' style was not available before version 2.1.0 in R. It always plots the trees +#' vertically (from top to bottom). +#' \item `"R"`: will use the style defined from XGBoost's R interface, which predates +#' the introduction of the standardized style from the core library. It might plot +#' the trees horizontally (from left to right). +#' } +#' +#' Note that `style="xgboost"` is only supported when all of the following conditions are met:\itemize{ +#' \item Only a single tree is being plotted. +#' \item Node IDs are not added to the graph. +#' \item The graph is being returned as `htmlwidget` (`render=TRUE`). +#' } #' @param ... currently not used. 
#' #' @details #' -#' The content of each node is visualized like this: +#' When using `style="xgboost"`, the content of each node is visualized as follows: +#' - For non-terminal nodes, it will display the split condition (number or name if +#' available, and the condition that would decide to which node to go next). +#' - Those nodes will be connected to their children by arrows that indicate whether the +#' branch corresponds to the condition being met or not being met. +#' - Terminal (leaf) nodes contain the margin to add when ending there. +#' +#' When using `style="R"`, the content of each node is visualized like this: #' - *Feature name*. #' - *Cover:* The sum of second order gradients of training data. #' For the squared loss, this simply corresponds to the number of instances in the node. @@ -57,8 +79,13 @@ #' objective = "binary:logistic" #' ) #' +#' # plot the second tree (tree indices are zero-based), using the style from xgboost's core library +#' # (this plot should look identical to the ones generated from other +#' # interfaces like the python package for xgboost) +#' xgb.plot.tree(model = bst, trees = 1, style = "xgboost") +#' #' # plot all the trees -#' xgb.plot.tree(model = bst) +#' xgb.plot.tree(model = bst, trees = NULL) #' #' # plot only the first tree and display the node ID: #' xgb.plot.tree(model = bst, trees = 0, show_node_id = TRUE) @@ -77,7 +104,7 @@ #' #' @export xgb.plot.tree <- function(feature_names = NULL, model = NULL, trees = NULL, plot_width = NULL, plot_height = NULL, - render = TRUE, show_node_id = FALSE, ...) { + render = TRUE, show_node_id = FALSE, style = c("R", "xgboost"), ...) { check.deprecation(...) if (!inherits(model, "xgb.Booster")) { stop("model: Has to be an object of class xgb.Booster") @@ -87,6 +114,22 @@ xgb.plot.tree <- function(feature_names = NULL, model = NULL, trees = NULL, plot stop("DiagrammeR package is required for xgb.plot.tree", call. 
= FALSE) } + style <- as.character(head(style, 1L)) + stopifnot(style %in% c("R", "xgboost")) + if (style == "xgboost") { + if (NROW(trees) != 1L || !render || show_node_id) { + stop("style='xgboost' is only supported for single, rendered tree, without node IDs.") + } + if (!is.null(feature_names)) { + stop( + "style='xgboost' cannot override 'feature_names'. Will automatically take them from the model." + ) + } + + txt <- xgb.dump(model, dump_format = "dot") + return(DiagrammeR::grViz(txt[[trees + 1]], width = plot_width, height = plot_height)) + } + dt <- xgb.model.dt.tree(feature_names = feature_names, model = model, trees = trees) dt[, label := paste0(Feature, "\nCover: ", Cover, ifelse(Feature == "Leaf", "\nValue: ", "\nGain: "), Quality)] diff --git a/R-package/man/xgb.dump.Rd b/R-package/man/xgb.dump.Rd index 791e74d96..2cdb6b16a 100644 --- a/R-package/man/xgb.dump.Rd +++ b/R-package/man/xgb.dump.Rd @@ -9,7 +9,7 @@ xgb.dump( fname = NULL, fmap = "", with_stats = FALSE, - dump_format = c("text", "json"), + dump_format = c("text", "json", "dot"), ... ) } @@ -29,7 +29,10 @@ When this option is on, the model dump contains two additional values: gain is the approximate loss function gain we get in each split; cover is the sum of second order gradient in each node.} -\item{dump_format}{either 'text' or 'json' format could be specified.} +\item{dump_format}{either 'text', 'json', or 'dot' (graphviz) format could be specified. 
+ +Format 'dot' for a single tree can be passed directly to packages that consume this format +for graph visualization, such as function \code{\link[DiagrammeR:grViz]{DiagrammeR::grViz()}}} \item{...}{currently not used} } @@ -57,4 +60,8 @@ print(xgb.dump(bst, with_stats = TRUE)) # print in JSON format: cat(xgb.dump(bst, with_stats = TRUE, dump_format='json')) +# plot first tree leveraging the 'dot' format +if (requireNamespace('DiagrammeR', quietly = TRUE)) { + DiagrammeR::grViz(xgb.dump(bst, dump_format = "dot")[[1L]]) +} } diff --git a/R-package/man/xgb.plot.tree.Rd b/R-package/man/xgb.plot.tree.Rd index 7571487eb..a09bb7183 100644 --- a/R-package/man/xgb.plot.tree.Rd +++ b/R-package/man/xgb.plot.tree.Rd @@ -12,6 +12,7 @@ xgb.plot.tree( plot_height = NULL, render = TRUE, show_node_id = FALSE, + style = c("R", "xgboost"), ... ) } @@ -34,6 +35,22 @@ The values are passed to \code{\link[DiagrammeR:render_graph]{DiagrammeR::render \item{show_node_id}{a logical flag for whether to show node id's in the graph.} +\item{style}{Style to use for the plot. Options are:\itemize{ +\item \code{"xgboost"}: will use the plot style defined in the core XGBoost library, +which is shared between different interfaces through the 'dot' format. This +style was not available before version 2.1.0 in R. It always plots the trees +vertically (from top to bottom). +\item \code{"R"}: will use the style defined from XGBoost's R interface, which predates +the introduction of the standardized style from the core library. It might plot +the trees horizontally (from left to right). +} + +Note that \code{style="xgboost"} is only supported when all of the following conditions are met:\itemize{ +\item Only a single tree is being plotted. +\item Node IDs are not added to the graph. +\item The graph is being returned as \code{htmlwidget} (\code{render=TRUE}). 
+}} + \item{...}{currently not used.} } \value{ @@ -51,7 +68,16 @@ before rendering the graph with \code{\link[DiagrammeR:render_graph]{DiagrammeR: Read a tree model text dump and plot the model. } \details{ -The content of each node is visualized like this: +When using \code{style="xgboost"}, the content of each node is visualized as follows: +\itemize{ +\item For non-terminal nodes, it will display the split condition (number or name if +available, and the condition that would decide to which node to go next). +\item Those nodes will be connected to their children by arrows that indicate whether the +branch corresponds to the condition being met or not being met. +\item Terminal (leaf) nodes contain the margin to add when ending there. +} + +When using \code{style="R"}, the content of each node is visualized like this: \itemize{ \item \emph{Feature name}. \item \emph{Cover:} The sum of second order gradients of training data. @@ -83,8 +109,13 @@ bst <- xgboost( objective = "binary:logistic" ) +# plot the second tree (tree indices are zero-based), using the style from xgboost's core library +# (this plot should look identical to the ones generated from other +# interfaces like the python package for xgboost) +xgb.plot.tree(model = bst, trees = 1, style = "xgboost") + # plot all the trees -xgb.plot.tree(model = bst) +xgb.plot.tree(model = bst, trees = NULL) # plot only the first tree and display the node ID: xgb.plot.tree(model = bst, trees = 0, show_node_id = TRUE)