[R] Enable 'dot' dump format (#9930)
This commit is contained in:
parent
ef8bdaa047
commit
e40c4260ed
@ -13,7 +13,10 @@
|
||||
#' When this option is on, the model dump contains two additional values:
|
||||
#' gain is the approximate loss function gain we get in each split;
|
||||
#' cover is the sum of second order gradient in each node.
|
||||
#' @param dump_format either 'text' or 'json' format could be specified.
|
||||
#' @param dump_format either 'text', 'json', or 'dot' (graphviz) format could be specified.
|
||||
#'
|
||||
#' Format 'dot' for a single tree can be passed directly to packages that consume this format
|
||||
#' for graph visualization, such as function [DiagrammeR::grViz()]
|
||||
#' @param ... currently not used
|
||||
#'
|
||||
#' @return
|
||||
@ -37,9 +40,13 @@
|
||||
#' # print in JSON format:
|
||||
#' cat(xgb.dump(bst, with_stats = TRUE, dump_format='json'))
|
||||
#'
|
||||
#' # plot first tree leveraging the 'dot' format
|
||||
#' if (requireNamespace('DiagrammeR', quietly = TRUE)) {
|
||||
#' DiagrammeR::grViz(xgb.dump(bst, dump_format = "dot")[[1L]])
|
||||
#' }
|
||||
#' @export
|
||||
xgb.dump <- function(model, fname = NULL, fmap = "", with_stats = FALSE,
|
||||
dump_format = c("text", "json"), ...) {
|
||||
dump_format = c("text", "json", "dot"), ...) {
|
||||
check.deprecation(...)
|
||||
dump_format <- match.arg(dump_format)
|
||||
if (!inherits(model, "xgb.Booster"))
|
||||
@ -52,6 +59,9 @@ xgb.dump <- function(model, fname = NULL, fmap = "", with_stats = FALSE,
|
||||
model <- xgb.Booster.complete(model)
|
||||
model_dump <- .Call(XGBoosterDumpModel_R, model$handle, NVL(fmap, "")[1], as.integer(with_stats),
|
||||
as.character(dump_format))
|
||||
if (dump_format == "dot") {
|
||||
return(sapply(model_dump, function(x) gsub("^booster\\[\\d+\\]\\n", "\\1", x)))
|
||||
}
|
||||
|
||||
if (is.null(fname))
|
||||
model_dump <- gsub('\t', '', model_dump, fixed = TRUE)
|
||||
|
||||
@ -14,11 +14,33 @@
|
||||
#' The values are passed to [DiagrammeR::render_graph()].
|
||||
#' @param render Should the graph be rendered or not? The default is `TRUE`.
|
||||
#' @param show_node_id a logical flag for whether to show node id's in the graph.
|
||||
#' @param style Style to use for the plot. Options are:\itemize{
|
||||
#' \item `"xgboost"`: will use the plot style defined in the core XGBoost library,
|
||||
#' which is shared between different interfaces through the 'dot' format. This
|
||||
#' style was not available before version 2.1.0 in R. It always plots the trees
|
||||
#' vertically (from top to bottom).
|
||||
#' \item `"R"`: will use the style defined from XGBoost's R interface, which predates
|
||||
#' the introduction of the standardized style from the core library. It might plot
|
||||
#' the trees horizontally (from left to right).
|
||||
#' }
|
||||
#'
|
||||
#' Note that `style="xgboost"` is only supported when all of the following conditions are met:\itemize{
|
||||
#' \item Only a single tree is being plotted.
|
||||
#' \item Node IDs are not added to the graph.
|
||||
#' \item The graph is being returned as `htmlwidget` (`render=TRUE`).
|
||||
#' }
|
||||
#' @param ... currently not used.
|
||||
#'
|
||||
#' @details
|
||||
#'
|
||||
#' The content of each node is visualized like this:
|
||||
#' When using `style="xgboost"`, the content of each node is visualized as follows:
|
||||
#' - For non-terminal nodes, it will display the split condition (number or name if
|
||||
#' available, and the condition that would decide to which node to go next).
|
||||
#' - Those nodes will be connected to their children by arrows that indicate whether the
|
||||
#' branch corresponds to the condition being met or not being met.
|
||||
#' - Terminal (leaf) nodes contain the margin to add when ending there.
|
||||
#'
|
||||
#' When using `style="R"`, the content of each node is visualized like this:
|
||||
#' - *Feature name*.
|
||||
#' - *Cover:* The sum of second order gradients of training data.
|
||||
#' For the squared loss, this simply corresponds to the number of instances in the node.
|
||||
@ -57,8 +79,13 @@
|
||||
#' objective = "binary:logistic"
|
||||
#' )
|
||||
#'
|
||||
#' # plot the first tree, using the style from xgboost's core library
|
||||
#' # (this plot should look identical to the ones generated from other
|
||||
#' # interfaces like the python package for xgboost)
|
||||
#' xgb.plot.tree(model = bst, trees = 1, style = "xgboost")
|
||||
#'
|
||||
#' # plot all the trees
|
||||
#' xgb.plot.tree(model = bst)
|
||||
#' xgb.plot.tree(model = bst, trees = NULL)
|
||||
#'
|
||||
#' # plot only the first tree and display the node ID:
|
||||
#' xgb.plot.tree(model = bst, trees = 0, show_node_id = TRUE)
|
||||
@ -77,7 +104,7 @@
|
||||
#'
|
||||
#' @export
|
||||
xgb.plot.tree <- function(feature_names = NULL, model = NULL, trees = NULL, plot_width = NULL, plot_height = NULL,
|
||||
render = TRUE, show_node_id = FALSE, ...) {
|
||||
render = TRUE, show_node_id = FALSE, style = c("R", "xgboost"), ...) {
|
||||
check.deprecation(...)
|
||||
if (!inherits(model, "xgb.Booster")) {
|
||||
stop("model: Has to be an object of class xgb.Booster")
|
||||
@ -87,6 +114,22 @@ xgb.plot.tree <- function(feature_names = NULL, model = NULL, trees = NULL, plot
|
||||
stop("DiagrammeR package is required for xgb.plot.tree", call. = FALSE)
|
||||
}
|
||||
|
||||
style <- as.character(head(style, 1L))
|
||||
stopifnot(style %in% c("R", "xgboost"))
|
||||
if (style == "xgboost") {
|
||||
if (NROW(trees) != 1L || !render || show_node_id) {
|
||||
stop("style='xgboost' is only supported for single, rendered tree, without node IDs.")
|
||||
}
|
||||
if (!is.null(feature_names)) {
|
||||
stop(
|
||||
"style='xgboost' cannot override 'feature_names'. Will automatically take them from the model."
|
||||
)
|
||||
}
|
||||
|
||||
txt <- xgb.dump(model, dump_format = "dot")
|
||||
return(DiagrammeR::grViz(txt[[trees + 1]], width = plot_width, height = plot_height))
|
||||
}
|
||||
|
||||
dt <- xgb.model.dt.tree(feature_names = feature_names, model = model, trees = trees)
|
||||
|
||||
dt[, label := paste0(Feature, "\nCover: ", Cover, ifelse(Feature == "Leaf", "\nValue: ", "\nGain: "), Quality)]
|
||||
|
||||
@ -9,7 +9,7 @@ xgb.dump(
|
||||
fname = NULL,
|
||||
fmap = "",
|
||||
with_stats = FALSE,
|
||||
dump_format = c("text", "json"),
|
||||
dump_format = c("text", "json", "dot"),
|
||||
...
|
||||
)
|
||||
}
|
||||
@ -29,7 +29,10 @@ When this option is on, the model dump contains two additional values:
|
||||
gain is the approximate loss function gain we get in each split;
|
||||
cover is the sum of second order gradient in each node.}
|
||||
|
||||
\item{dump_format}{either 'text' or 'json' format could be specified.}
|
||||
\item{dump_format}{either 'text', 'json', or 'dot' (graphviz) format could be specified.
|
||||
|
||||
Format 'dot' for a single tree can be passed directly to packages that consume this format
|
||||
for graph visualization, such as function \code{\link[DiagrammeR:grViz]{DiagrammeR::grViz()}}}
|
||||
|
||||
\item{...}{currently not used}
|
||||
}
|
||||
@ -57,4 +60,8 @@ print(xgb.dump(bst, with_stats = TRUE))
|
||||
# print in JSON format:
|
||||
cat(xgb.dump(bst, with_stats = TRUE, dump_format='json'))
|
||||
|
||||
# plot first tree leveraging the 'dot' format
|
||||
if (requireNamespace('DiagrammeR', quietly = TRUE)) {
|
||||
DiagrammeR::grViz(xgb.dump(bst, dump_format = "dot")[[1L]])
|
||||
}
|
||||
}
|
||||
|
||||
@ -12,6 +12,7 @@ xgb.plot.tree(
|
||||
plot_height = NULL,
|
||||
render = TRUE,
|
||||
show_node_id = FALSE,
|
||||
style = c("R", "xgboost"),
|
||||
...
|
||||
)
|
||||
}
|
||||
@ -34,6 +35,22 @@ The values are passed to \code{\link[DiagrammeR:render_graph]{DiagrammeR::render
|
||||
|
||||
\item{show_node_id}{a logical flag for whether to show node id's in the graph.}
|
||||
|
||||
\item{style}{Style to use for the plot. Options are:\itemize{
|
||||
\item \code{"xgboost"}: will use the plot style defined in the core XGBoost library,
|
||||
which is shared between different interfaces through the 'dot' format. This
|
||||
style was not available before version 2.1.0 in R. It always plots the trees
|
||||
vertically (from top to bottom).
|
||||
\item \code{"R"}: will use the style defined from XGBoost's R interface, which predates
|
||||
the introduction of the standardized style from the core library. It might plot
|
||||
the trees horizontally (from left to right).
|
||||
}
|
||||
|
||||
Note that \code{style="xgboost"} is only supported when all of the following conditions are met:\itemize{
|
||||
\item Only a single tree is being plotted.
|
||||
\item Node IDs are not added to the graph.
|
||||
\item The graph is being returned as \code{htmlwidget} (\code{render=TRUE}).
|
||||
}}
|
||||
|
||||
\item{...}{currently not used.}
|
||||
}
|
||||
\value{
|
||||
@ -51,7 +68,16 @@ before rendering the graph with \code{\link[DiagrammeR:render_graph]{DiagrammeR:
|
||||
Read a tree model text dump and plot the model.
|
||||
}
|
||||
\details{
|
||||
The content of each node is visualized like this:
|
||||
When using \code{style="xgboost"}, the content of each node is visualized as follows:
|
||||
\itemize{
|
||||
\item For non-terminal nodes, it will display the split condition (number or name if
|
||||
available, and the condition that would decide to which node to go next).
|
||||
\item Those nodes will be connected to their children by arrows that indicate whether the
|
||||
branch corresponds to the condition being met or not being met.
|
||||
\item Terminal (leaf) nodes contain the margin to add when ending there.
|
||||
}
|
||||
|
||||
When using \code{style="R"}, the content of each node is visualized like this:
|
||||
\itemize{
|
||||
\item \emph{Feature name}.
|
||||
\item \emph{Cover:} The sum of second order gradients of training data.
|
||||
@ -83,8 +109,13 @@ bst <- xgboost(
|
||||
objective = "binary:logistic"
|
||||
)
|
||||
|
||||
# plot the first tree, using the style from xgboost's core library
|
||||
# (this plot should look identical to the ones generated from other
|
||||
# interfaces like the python package for xgboost)
|
||||
xgb.plot.tree(model = bst, trees = 1, style = "xgboost")
|
||||
|
||||
# plot all the trees
|
||||
xgb.plot.tree(model = bst)
|
||||
xgb.plot.tree(model = bst, trees = NULL)
|
||||
|
||||
# plot only the first tree and display the node ID:
|
||||
xgb.plot.tree(model = bst, trees = 0, show_node_id = TRUE)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user