[R] Enable 'dot' dump format (#9930)
This commit is contained in:
parent
ef8bdaa047
commit
e40c4260ed
@ -13,7 +13,10 @@
|
|||||||
#' When this option is on, the model dump contains two additional values:
|
#' When this option is on, the model dump contains two additional values:
|
||||||
#' gain is the approximate loss function gain we get in each split;
|
#' gain is the approximate loss function gain we get in each split;
|
||||||
#' cover is the sum of second order gradient in each node.
|
#' cover is the sum of second order gradient in each node.
|
||||||
#' @param dump_format either 'text' or 'json' format could be specified.
|
#' @param dump_format either 'text', 'json', or 'dot' (graphviz) format could be specified.
|
||||||
|
#'
|
||||||
|
#' Format 'dot' for a single tree can be passed directly to packages that consume this format
|
||||||
|
#' for graph visualization, such as function [DiagrammeR::grViz()]
|
||||||
#' @param ... currently not used
|
#' @param ... currently not used
|
||||||
#'
|
#'
|
||||||
#' @return
|
#' @return
|
||||||
@ -37,9 +40,13 @@
|
|||||||
#' # print in JSON format:
|
#' # print in JSON format:
|
||||||
#' cat(xgb.dump(bst, with_stats = TRUE, dump_format='json'))
|
#' cat(xgb.dump(bst, with_stats = TRUE, dump_format='json'))
|
||||||
#'
|
#'
|
||||||
|
#' # plot first tree leveraging the 'dot' format
|
||||||
|
#' if (requireNamespace('DiagrammeR', quietly = TRUE)) {
|
||||||
|
#' DiagrammeR::grViz(xgb.dump(bst, dump_format = "dot")[[1L]])
|
||||||
|
#' }
|
||||||
#' @export
|
#' @export
|
||||||
xgb.dump <- function(model, fname = NULL, fmap = "", with_stats = FALSE,
|
xgb.dump <- function(model, fname = NULL, fmap = "", with_stats = FALSE,
|
||||||
dump_format = c("text", "json"), ...) {
|
dump_format = c("text", "json", "dot"), ...) {
|
||||||
check.deprecation(...)
|
check.deprecation(...)
|
||||||
dump_format <- match.arg(dump_format)
|
dump_format <- match.arg(dump_format)
|
||||||
if (!inherits(model, "xgb.Booster"))
|
if (!inherits(model, "xgb.Booster"))
|
||||||
@ -52,6 +59,9 @@ xgb.dump <- function(model, fname = NULL, fmap = "", with_stats = FALSE,
|
|||||||
model <- xgb.Booster.complete(model)
|
model <- xgb.Booster.complete(model)
|
||||||
model_dump <- .Call(XGBoosterDumpModel_R, model$handle, NVL(fmap, "")[1], as.integer(with_stats),
|
model_dump <- .Call(XGBoosterDumpModel_R, model$handle, NVL(fmap, "")[1], as.integer(with_stats),
|
||||||
as.character(dump_format))
|
as.character(dump_format))
|
||||||
|
if (dump_format == "dot") {
|
||||||
|
return(sapply(model_dump, function(x) gsub("^booster\\[\\d+\\]\\n", "\\1", x)))
|
||||||
|
}
|
||||||
|
|
||||||
if (is.null(fname))
|
if (is.null(fname))
|
||||||
model_dump <- gsub('\t', '', model_dump, fixed = TRUE)
|
model_dump <- gsub('\t', '', model_dump, fixed = TRUE)
|
||||||
|
|||||||
@ -14,11 +14,33 @@
|
|||||||
#' The values are passed to [DiagrammeR::render_graph()].
|
#' The values are passed to [DiagrammeR::render_graph()].
|
||||||
#' @param render Should the graph be rendered or not? The default is `TRUE`.
|
#' @param render Should the graph be rendered or not? The default is `TRUE`.
|
||||||
#' @param show_node_id a logical flag for whether to show node id's in the graph.
|
#' @param show_node_id a logical flag for whether to show node id's in the graph.
|
||||||
|
#' @param style Style to use for the plot. Options are:\itemize{
|
||||||
|
#' \item `"xgboost"`: will use the plot style defined in the core XGBoost library,
|
||||||
|
#' which is shared between different interfaces through the 'dot' format. This
|
||||||
|
#' style was not available before version 2.1.0 in R. It always plots the trees
|
||||||
|
#' vertically (from top to bottom).
|
||||||
|
#' \item `"R"`: will use the style defined from XGBoost's R interface, which predates
|
||||||
|
#' the introducition of the standardized style from the core library. It might plot
|
||||||
|
#' the trees horizontally (from left to right).
|
||||||
|
#' }
|
||||||
|
#'
|
||||||
|
#' Note that `style="xgboost"` is only supported when all of the following conditions are met:\itemize{
|
||||||
|
#' \item Only a single tree is being plotted.
|
||||||
|
#' \item Node IDs are not added to the graph.
|
||||||
|
#' \item The graph is being returned as `htmlwidget` (`render=TRUE`).
|
||||||
|
#' }
|
||||||
#' @param ... currently not used.
|
#' @param ... currently not used.
|
||||||
#'
|
#'
|
||||||
#' @details
|
#' @details
|
||||||
#'
|
#'
|
||||||
#' The content of each node is visualized like this:
|
#' When using `style="xgboost"`, the content of each node is visualized as follows:
|
||||||
|
#' - For non-terminal nodes, it will display the split condition (number or name if
|
||||||
|
#' available, and the condition that would decide to which node to go next).
|
||||||
|
#' - Those nodes will be connected to their children by arrows that indicate whether the
|
||||||
|
#' branch corresponds to the condition being met or not being met.
|
||||||
|
#' - Terminal (leaf) nodes contain the margin to add when ending there.
|
||||||
|
#'
|
||||||
|
#' When using `style="R"`, the content of each node is visualized like this:
|
||||||
#' - *Feature name*.
|
#' - *Feature name*.
|
||||||
#' - *Cover:* The sum of second order gradients of training data.
|
#' - *Cover:* The sum of second order gradients of training data.
|
||||||
#' For the squared loss, this simply corresponds to the number of instances in the node.
|
#' For the squared loss, this simply corresponds to the number of instances in the node.
|
||||||
@ -57,8 +79,13 @@
|
|||||||
#' objective = "binary:logistic"
|
#' objective = "binary:logistic"
|
||||||
#' )
|
#' )
|
||||||
#'
|
#'
|
||||||
|
#' # plot the first tree, using the style from xgboost's core library
|
||||||
|
#' # (this plot should look identical to the ones generated from other
|
||||||
|
#' # interfaces like the python package for xgboost)
|
||||||
|
#' xgb.plot.tree(model = bst, trees = 1, style = "xgboost")
|
||||||
|
#'
|
||||||
#' # plot all the trees
|
#' # plot all the trees
|
||||||
#' xgb.plot.tree(model = bst)
|
#' xgb.plot.tree(model = bst, trees = NULL)
|
||||||
#'
|
#'
|
||||||
#' # plot only the first tree and display the node ID:
|
#' # plot only the first tree and display the node ID:
|
||||||
#' xgb.plot.tree(model = bst, trees = 0, show_node_id = TRUE)
|
#' xgb.plot.tree(model = bst, trees = 0, show_node_id = TRUE)
|
||||||
@ -77,7 +104,7 @@
|
|||||||
#'
|
#'
|
||||||
#' @export
|
#' @export
|
||||||
xgb.plot.tree <- function(feature_names = NULL, model = NULL, trees = NULL, plot_width = NULL, plot_height = NULL,
|
xgb.plot.tree <- function(feature_names = NULL, model = NULL, trees = NULL, plot_width = NULL, plot_height = NULL,
|
||||||
render = TRUE, show_node_id = FALSE, ...) {
|
render = TRUE, show_node_id = FALSE, style = c("R", "xgboost"), ...) {
|
||||||
check.deprecation(...)
|
check.deprecation(...)
|
||||||
if (!inherits(model, "xgb.Booster")) {
|
if (!inherits(model, "xgb.Booster")) {
|
||||||
stop("model: Has to be an object of class xgb.Booster")
|
stop("model: Has to be an object of class xgb.Booster")
|
||||||
@ -87,6 +114,22 @@ xgb.plot.tree <- function(feature_names = NULL, model = NULL, trees = NULL, plot
|
|||||||
stop("DiagrammeR package is required for xgb.plot.tree", call. = FALSE)
|
stop("DiagrammeR package is required for xgb.plot.tree", call. = FALSE)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
style <- as.character(head(style, 1L))
|
||||||
|
stopifnot(style %in% c("R", "xgboost"))
|
||||||
|
if (style == "xgboost") {
|
||||||
|
if (NROW(trees) != 1L || !render || show_node_id) {
|
||||||
|
stop("style='xgboost' is only supported for single, rendered tree, without node IDs.")
|
||||||
|
}
|
||||||
|
if (!is.null(feature_names)) {
|
||||||
|
stop(
|
||||||
|
"style='xgboost' cannot override 'feature_names'. Will automatically take them from the model."
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
txt <- xgb.dump(model, dump_format = "dot")
|
||||||
|
return(DiagrammeR::grViz(txt[[trees + 1]], width = plot_width, height = plot_height))
|
||||||
|
}
|
||||||
|
|
||||||
dt <- xgb.model.dt.tree(feature_names = feature_names, model = model, trees = trees)
|
dt <- xgb.model.dt.tree(feature_names = feature_names, model = model, trees = trees)
|
||||||
|
|
||||||
dt[, label := paste0(Feature, "\nCover: ", Cover, ifelse(Feature == "Leaf", "\nValue: ", "\nGain: "), Quality)]
|
dt[, label := paste0(Feature, "\nCover: ", Cover, ifelse(Feature == "Leaf", "\nValue: ", "\nGain: "), Quality)]
|
||||||
|
|||||||
@ -9,7 +9,7 @@ xgb.dump(
|
|||||||
fname = NULL,
|
fname = NULL,
|
||||||
fmap = "",
|
fmap = "",
|
||||||
with_stats = FALSE,
|
with_stats = FALSE,
|
||||||
dump_format = c("text", "json"),
|
dump_format = c("text", "json", "dot"),
|
||||||
...
|
...
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@ -29,7 +29,10 @@ When this option is on, the model dump contains two additional values:
|
|||||||
gain is the approximate loss function gain we get in each split;
|
gain is the approximate loss function gain we get in each split;
|
||||||
cover is the sum of second order gradient in each node.}
|
cover is the sum of second order gradient in each node.}
|
||||||
|
|
||||||
\item{dump_format}{either 'text' or 'json' format could be specified.}
|
\item{dump_format}{either 'text', 'json', or 'dot' (graphviz) format could be specified.
|
||||||
|
|
||||||
|
Format 'dot' for a single tree can be passed directly to packages that consume this format
|
||||||
|
for graph visualization, such as function \code{\link[DiagrammeR:grViz]{DiagrammeR::grViz()}}}
|
||||||
|
|
||||||
\item{...}{currently not used}
|
\item{...}{currently not used}
|
||||||
}
|
}
|
||||||
@ -57,4 +60,8 @@ print(xgb.dump(bst, with_stats = TRUE))
|
|||||||
# print in JSON format:
|
# print in JSON format:
|
||||||
cat(xgb.dump(bst, with_stats = TRUE, dump_format='json'))
|
cat(xgb.dump(bst, with_stats = TRUE, dump_format='json'))
|
||||||
|
|
||||||
|
# plot first tree leveraging the 'dot' format
|
||||||
|
if (requireNamespace('DiagrammeR', quietly = TRUE)) {
|
||||||
|
DiagrammeR::grViz(xgb.dump(bst, dump_format = "dot")[[1L]])
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -12,6 +12,7 @@ xgb.plot.tree(
|
|||||||
plot_height = NULL,
|
plot_height = NULL,
|
||||||
render = TRUE,
|
render = TRUE,
|
||||||
show_node_id = FALSE,
|
show_node_id = FALSE,
|
||||||
|
style = c("R", "xgboost"),
|
||||||
...
|
...
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@ -34,6 +35,22 @@ The values are passed to \code{\link[DiagrammeR:render_graph]{DiagrammeR::render
|
|||||||
|
|
||||||
\item{show_node_id}{a logical flag for whether to show node id's in the graph.}
|
\item{show_node_id}{a logical flag for whether to show node id's in the graph.}
|
||||||
|
|
||||||
|
\item{style}{Style to use for the plot. Options are:\itemize{
|
||||||
|
\item \code{"xgboost"}: will use the plot style defined in the core XGBoost library,
|
||||||
|
which is shared between different interfaces through the 'dot' format. This
|
||||||
|
style was not available before version 2.1.0 in R. It always plots the trees
|
||||||
|
vertically (from top to bottom).
|
||||||
|
\item \code{"R"}: will use the style defined from XGBoost's R interface, which predates
|
||||||
|
the introducition of the standardized style from the core library. It might plot
|
||||||
|
the trees horizontally (from left to right).
|
||||||
|
}
|
||||||
|
|
||||||
|
Note that \code{style="xgboost"} is only supported when all of the following conditions are met:\itemize{
|
||||||
|
\item Only a single tree is being plotted.
|
||||||
|
\item Node IDs are not added to the graph.
|
||||||
|
\item The graph is being returned as \code{htmlwidget} (\code{render=TRUE}).
|
||||||
|
}}
|
||||||
|
|
||||||
\item{...}{currently not used.}
|
\item{...}{currently not used.}
|
||||||
}
|
}
|
||||||
\value{
|
\value{
|
||||||
@ -51,7 +68,16 @@ before rendering the graph with \code{\link[DiagrammeR:render_graph]{DiagrammeR:
|
|||||||
Read a tree model text dump and plot the model.
|
Read a tree model text dump and plot the model.
|
||||||
}
|
}
|
||||||
\details{
|
\details{
|
||||||
The content of each node is visualized like this:
|
When using \code{style="xgboost"}, the content of each node is visualized as follows:
|
||||||
|
\itemize{
|
||||||
|
\item For non-terminal nodes, it will display the split condition (number or name if
|
||||||
|
available, and the condition that would decide to which node to go next).
|
||||||
|
\item Those nodes will be connected to their children by arrows that indicate whether the
|
||||||
|
branch corresponds to the condition being met or not being met.
|
||||||
|
\item Terminal (leaf) nodes contain the margin to add when ending there.
|
||||||
|
}
|
||||||
|
|
||||||
|
When using \code{style="R"}, the content of each node is visualized like this:
|
||||||
\itemize{
|
\itemize{
|
||||||
\item \emph{Feature name}.
|
\item \emph{Feature name}.
|
||||||
\item \emph{Cover:} The sum of second order gradients of training data.
|
\item \emph{Cover:} The sum of second order gradients of training data.
|
||||||
@ -83,8 +109,13 @@ bst <- xgboost(
|
|||||||
objective = "binary:logistic"
|
objective = "binary:logistic"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# plot the first tree, using the style from xgboost's core library
|
||||||
|
# (this plot should look identical to the ones generated from other
|
||||||
|
# interfaces like the python package for xgboost)
|
||||||
|
xgb.plot.tree(model = bst, trees = 1, style = "xgboost")
|
||||||
|
|
||||||
# plot all the trees
|
# plot all the trees
|
||||||
xgb.plot.tree(model = bst)
|
xgb.plot.tree(model = bst, trees = NULL)
|
||||||
|
|
||||||
# plot only the first tree and display the node ID:
|
# plot only the first tree and display the node ID:
|
||||||
xgb.plot.tree(model = bst, trees = 0, show_node_id = TRUE)
|
xgb.plot.tree(model = bst, trees = 0, show_node_id = TRUE)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user