[R] Enable 'dot' dump format (#9930)

This commit is contained in:
david-cortes 2023-12-30 06:28:27 +01:00 committed by GitHub
parent ef8bdaa047
commit e40c4260ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 100 additions and 9 deletions

View File

@ -13,7 +13,10 @@
#' When this option is on, the model dump contains two additional values:
#' gain is the approximate loss function gain we get in each split;
#' cover is the sum of second order gradient in each node.
#' @param dump_format either 'text' or 'json' format could be specified.
#' @param dump_format either 'text', 'json', or 'dot' (graphviz) format could be specified.
#'
#' Format 'dot' for a single tree can be passed directly to packages that consume this format
#' for graph visualization, such as function [DiagrammeR::grViz()].
#' @param ... currently not used
#'
#' @return
@ -37,9 +40,13 @@
#' # print in JSON format:
#' cat(xgb.dump(bst, with_stats = TRUE, dump_format='json'))
#'
#' # plot first tree leveraging the 'dot' format
#' if (requireNamespace('DiagrammeR', quietly = TRUE)) {
#' DiagrammeR::grViz(xgb.dump(bst, dump_format = "dot")[[1L]])
#' }
#' @export
xgb.dump <- function(model, fname = NULL, fmap = "", with_stats = FALSE,
dump_format = c("text", "json"), ...) {
dump_format = c("text", "json", "dot"), ...) {
check.deprecation(...)
dump_format <- match.arg(dump_format)
if (!inherits(model, "xgb.Booster"))
@ -52,6 +59,9 @@ xgb.dump <- function(model, fname = NULL, fmap = "", with_stats = FALSE,
model <- xgb.Booster.complete(model)
model_dump <- .Call(XGBoosterDumpModel_R, model$handle, NVL(fmap, "")[1], as.integer(with_stats),
as.character(dump_format))
if (dump_format == "dot") {
return(sapply(model_dump, function(x) gsub("^booster\\[\\d+\\]\\n", "\\1", x)))
}
if (is.null(fname))
model_dump <- gsub('\t', '', model_dump, fixed = TRUE)

View File

@ -14,11 +14,33 @@
#' The values are passed to [DiagrammeR::render_graph()].
#' @param render Should the graph be rendered or not? The default is `TRUE`.
#' @param show_node_id a logical flag for whether to show node id's in the graph.
#' @param style Style to use for the plot. Options are:\itemize{
#' \item `"xgboost"`: will use the plot style defined in the core XGBoost library,
#' which is shared between different interfaces through the 'dot' format. This
#' style was not available before version 2.1.0 in R. It always plots the trees
#' vertically (from top to bottom).
#' \item `"R"`: will use the style defined from XGBoost's R interface, which predates
#' the introduction of the standardized style from the core library. It might plot
#' the trees horizontally (from left to right).
#' }
#'
#' Note that `style="xgboost"` is only supported when all of the following conditions are met:\itemize{
#' \item Only a single tree is being plotted.
#' \item Node IDs are not added to the graph.
#' \item The graph is being returned as `htmlwidget` (`render=TRUE`).
#' \item Argument `feature_names` is not supplied (feature names are taken from the model).
#' }
#' @param ... currently not used.
#'
#' @details
#'
#' The content of each node is visualized like this:
#' When using `style="xgboost"`, the content of each node is visualized as follows:
#' - For non-terminal nodes, it will display the feature used for the split (its number,
#' or its name if available) and the condition that decides which node to go to next.
#' - Those nodes will be connected to their children by arrows that indicate whether the
#' branch corresponds to the condition being met or not being met.
#' - Terminal (leaf) nodes contain the margin to add when ending there.
#'
#' When using `style="R"`, the content of each node is visualized like this:
#' - *Feature name*.
#' - *Cover:* The sum of second order gradients of training data.
#' For the squared loss, this simply corresponds to the number of instances in the node.
@ -57,8 +79,13 @@
#' objective = "binary:logistic"
#' )
#'
#' # plot the second tree (`trees` is zero-based), using the style from xgboost's core library
#' # (this plot should look identical to the ones generated from other
#' # interfaces like the python package for xgboost)
#' xgb.plot.tree(model = bst, trees = 1, style = "xgboost")
#'
#' # plot all the trees
#' xgb.plot.tree(model = bst)
#' xgb.plot.tree(model = bst, trees = NULL)
#'
#' # plot only the first tree and display the node ID:
#' xgb.plot.tree(model = bst, trees = 0, show_node_id = TRUE)
@ -77,7 +104,7 @@
#'
#' @export
xgb.plot.tree <- function(feature_names = NULL, model = NULL, trees = NULL, plot_width = NULL, plot_height = NULL,
render = TRUE, show_node_id = FALSE, ...) {
render = TRUE, show_node_id = FALSE, style = c("R", "xgboost"), ...) {
check.deprecation(...)
if (!inherits(model, "xgb.Booster")) {
stop("model: Has to be an object of class xgb.Booster")
@ -87,6 +114,22 @@ xgb.plot.tree <- function(feature_names = NULL, model = NULL, trees = NULL, plot
stop("DiagrammeR package is required for xgb.plot.tree", call. = FALSE)
}
style <- as.character(head(style, 1L))
stopifnot(style %in% c("R", "xgboost"))
if (style == "xgboost") {
if (NROW(trees) != 1L || !render || show_node_id) {
stop("style='xgboost' is only supported for single, rendered tree, without node IDs.")
}
if (!is.null(feature_names)) {
stop(
"style='xgboost' cannot override 'feature_names'. Will automatically take them from the model."
)
}
txt <- xgb.dump(model, dump_format = "dot")
return(DiagrammeR::grViz(txt[[trees + 1]], width = plot_width, height = plot_height))
}
dt <- xgb.model.dt.tree(feature_names = feature_names, model = model, trees = trees)
dt[, label := paste0(Feature, "\nCover: ", Cover, ifelse(Feature == "Leaf", "\nValue: ", "\nGain: "), Quality)]

View File

@ -9,7 +9,7 @@ xgb.dump(
fname = NULL,
fmap = "",
with_stats = FALSE,
dump_format = c("text", "json"),
dump_format = c("text", "json", "dot"),
...
)
}
@ -29,7 +29,10 @@ When this option is on, the model dump contains two additional values:
gain is the approximate loss function gain we get in each split;
cover is the sum of second order gradient in each node.}
\item{dump_format}{either 'text' or 'json' format could be specified.}
\item{dump_format}{either 'text', 'json', or 'dot' (graphviz) format could be specified.
Format 'dot' for a single tree can be passed directly to packages that consume this format
for graph visualization, such as function \code{\link[DiagrammeR:grViz]{DiagrammeR::grViz()}}.}
\item{...}{currently not used}
}
@ -57,4 +60,8 @@ print(xgb.dump(bst, with_stats = TRUE))
# print in JSON format:
cat(xgb.dump(bst, with_stats = TRUE, dump_format='json'))
# plot first tree leveraging the 'dot' format
if (requireNamespace('DiagrammeR', quietly = TRUE)) {
DiagrammeR::grViz(xgb.dump(bst, dump_format = "dot")[[1L]])
}
}

View File

@ -12,6 +12,7 @@ xgb.plot.tree(
plot_height = NULL,
render = TRUE,
show_node_id = FALSE,
style = c("R", "xgboost"),
...
)
}
@ -34,6 +35,22 @@ The values are passed to \code{\link[DiagrammeR:render_graph]{DiagrammeR::render
\item{show_node_id}{a logical flag for whether to show node id's in the graph.}
\item{style}{Style to use for the plot. Options are:\itemize{
\item \code{"xgboost"}: will use the plot style defined in the core XGBoost library,
which is shared between different interfaces through the 'dot' format. This
style was not available before version 2.1.0 in R. It always plots the trees
vertically (from top to bottom).
\item \code{"R"}: will use the style defined from XGBoost's R interface, which predates
the introduction of the standardized style from the core library. It might plot
the trees horizontally (from left to right).
}
Note that \code{style="xgboost"} is only supported when all of the following conditions are met:\itemize{
\item Only a single tree is being plotted.
\item Node IDs are not added to the graph.
\item The graph is being returned as \code{htmlwidget} (\code{render=TRUE}).
\item Argument \code{feature_names} is not supplied (feature names are taken from the model).
}}
\item{...}{currently not used.}
}
\value{
@ -51,7 +68,16 @@ before rendering the graph with \code{\link[DiagrammeR:render_graph]{DiagrammeR:
Read a tree model text dump and plot the model.
}
\details{
The content of each node is visualized like this:
When using \code{style="xgboost"}, the content of each node is visualized as follows:
\itemize{
\item For non-terminal nodes, it will display the feature used for the split (its number,
or its name if available) and the condition that decides which node to go to next.
\item Those nodes will be connected to their children by arrows that indicate whether the
branch corresponds to the condition being met or not being met.
\item Terminal (leaf) nodes contain the margin to add when ending there.
}
When using \code{style="R"}, the content of each node is visualized like this:
\itemize{
\item \emph{Feature name}.
\item \emph{Cover:} The sum of second order gradients of training data.
@ -83,8 +109,13 @@ bst <- xgboost(
objective = "binary:logistic"
)
# plot the second tree (`trees` is zero-based), using the style from xgboost's core library
# (this plot should look identical to the ones generated from other
# interfaces like the python package for xgboost)
xgb.plot.tree(model = bst, trees = 1, style = "xgboost")
# plot all the trees
xgb.plot.tree(model = bst)
xgb.plot.tree(model = bst, trees = NULL)
# plot only the first tree and display the node ID:
xgb.plot.tree(model = bst, trees = 0, show_node_id = TRUE)