[R] Enable 'dot' dump format (#9930)
This commit is contained in:
@@ -13,7 +13,10 @@
|
||||
#' When this option is on, the model dump contains two additional values:
|
||||
#' gain is the approximate loss function gain we get in each split;
|
||||
#' cover is the sum of second order gradient in each node.
|
||||
#' @param dump_format either 'text' or 'json' format could be specified.
|
||||
#' @param dump_format either 'text', 'json', or 'dot' (graphviz) format could be specified.
|
||||
#'
|
||||
#' Format 'dot' for a single tree can be passed directly to packages that consume this format
|
||||
#' for graph visualization, such as function [DiagrammeR::grViz()]
|
||||
#' @param ... currently not used
|
||||
#'
|
||||
#' @return
|
||||
@@ -37,9 +40,13 @@
|
||||
#' # print in JSON format:
|
||||
#' cat(xgb.dump(bst, with_stats = TRUE, dump_format='json'))
|
||||
#'
|
||||
#' # plot first tree leveraging the 'dot' format
|
||||
#' if (requireNamespace('DiagrammeR', quietly = TRUE)) {
|
||||
#' DiagrammeR::grViz(xgb.dump(bst, dump_format = "dot")[[1L]])
|
||||
#' }
|
||||
#' @export
|
||||
xgb.dump <- function(model, fname = NULL, fmap = "", with_stats = FALSE,
|
||||
dump_format = c("text", "json"), ...) {
|
||||
dump_format = c("text", "json", "dot"), ...) {
|
||||
check.deprecation(...)
|
||||
dump_format <- match.arg(dump_format)
|
||||
if (!inherits(model, "xgb.Booster"))
|
||||
@@ -52,6 +59,9 @@ xgb.dump <- function(model, fname = NULL, fmap = "", with_stats = FALSE,
|
||||
model <- xgb.Booster.complete(model)
|
||||
model_dump <- .Call(XGBoosterDumpModel_R, model$handle, NVL(fmap, "")[1], as.integer(with_stats),
|
||||
as.character(dump_format))
|
||||
if (dump_format == "dot") {
|
||||
return(sapply(model_dump, function(x) gsub("^booster\\[\\d+\\]\\n", "\\1", x)))
|
||||
}
|
||||
|
||||
if (is.null(fname))
|
||||
model_dump <- gsub('\t', '', model_dump, fixed = TRUE)
|
||||
|
||||
@@ -14,11 +14,33 @@
|
||||
#' The values are passed to [DiagrammeR::render_graph()].
|
||||
#' @param render Should the graph be rendered or not? The default is `TRUE`.
|
||||
#' @param show_node_id a logical flag for whether to show node id's in the graph.
|
||||
#' @param style Style to use for the plot. Options are:\itemize{
|
||||
#' \item `"xgboost"`: will use the plot style defined in the core XGBoost library,
|
||||
#' which is shared between different interfaces through the 'dot' format. This
|
||||
#' style was not available before version 2.1.0 in R. It always plots the trees
|
||||
#' vertically (from top to bottom).
|
||||
#' \item `"R"`: will use the style defined from XGBoost's R interface, which predates
|
||||
#' the introducition of the standardized style from the core library. It might plot
|
||||
#' the trees horizontally (from left to right).
|
||||
#' }
|
||||
#'
|
||||
#' Note that `style="xgboost"` is only supported when all of the following conditions are met:\itemize{
|
||||
#' \item Only a single tree is being plotted.
|
||||
#' \item Node IDs are not added to the graph.
|
||||
#' \item The graph is being returned as `htmlwidget` (`render=TRUE`).
|
||||
#' }
|
||||
#' @param ... currently not used.
|
||||
#'
|
||||
#' @details
|
||||
#'
|
||||
#' The content of each node is visualized like this:
|
||||
#' When using `style="xgboost"`, the content of each node is visualized as follows:
|
||||
#' - For non-terminal nodes, it will display the split condition (number or name if
|
||||
#' available, and the condition that would decide to which node to go next).
|
||||
#' - Those nodes will be connected to their children by arrows that indicate whether the
|
||||
#' branch corresponds to the condition being met or not being met.
|
||||
#' - Terminal (leaf) nodes contain the margin to add when ending there.
|
||||
#'
|
||||
#' When using `style="R"`, the content of each node is visualized like this:
|
||||
#' - *Feature name*.
|
||||
#' - *Cover:* The sum of second order gradients of training data.
|
||||
#' For the squared loss, this simply corresponds to the number of instances in the node.
|
||||
@@ -57,8 +79,13 @@
|
||||
#' objective = "binary:logistic"
|
||||
#' )
|
||||
#'
|
||||
#' # plot the first tree, using the style from xgboost's core library
|
||||
#' # (this plot should look identical to the ones generated from other
|
||||
#' # interfaces like the python package for xgboost)
|
||||
#' xgb.plot.tree(model = bst, trees = 1, style = "xgboost")
|
||||
#'
|
||||
#' # plot all the trees
|
||||
#' xgb.plot.tree(model = bst)
|
||||
#' xgb.plot.tree(model = bst, trees = NULL)
|
||||
#'
|
||||
#' # plot only the first tree and display the node ID:
|
||||
#' xgb.plot.tree(model = bst, trees = 0, show_node_id = TRUE)
|
||||
@@ -77,7 +104,7 @@
|
||||
#'
|
||||
#' @export
|
||||
xgb.plot.tree <- function(feature_names = NULL, model = NULL, trees = NULL, plot_width = NULL, plot_height = NULL,
|
||||
render = TRUE, show_node_id = FALSE, ...) {
|
||||
render = TRUE, show_node_id = FALSE, style = c("R", "xgboost"), ...) {
|
||||
check.deprecation(...)
|
||||
if (!inherits(model, "xgb.Booster")) {
|
||||
stop("model: Has to be an object of class xgb.Booster")
|
||||
@@ -87,6 +114,22 @@ xgb.plot.tree <- function(feature_names = NULL, model = NULL, trees = NULL, plot
|
||||
stop("DiagrammeR package is required for xgb.plot.tree", call. = FALSE)
|
||||
}
|
||||
|
||||
style <- as.character(head(style, 1L))
|
||||
stopifnot(style %in% c("R", "xgboost"))
|
||||
if (style == "xgboost") {
|
||||
if (NROW(trees) != 1L || !render || show_node_id) {
|
||||
stop("style='xgboost' is only supported for single, rendered tree, without node IDs.")
|
||||
}
|
||||
if (!is.null(feature_names)) {
|
||||
stop(
|
||||
"style='xgboost' cannot override 'feature_names'. Will automatically take them from the model."
|
||||
)
|
||||
}
|
||||
|
||||
txt <- xgb.dump(model, dump_format = "dot")
|
||||
return(DiagrammeR::grViz(txt[[trees + 1]], width = plot_width, height = plot_height))
|
||||
}
|
||||
|
||||
dt <- xgb.model.dt.tree(feature_names = feature_names, model = model, trees = trees)
|
||||
|
||||
dt[, label := paste0(Feature, "\nCover: ", Cover, ifelse(Feature == "Leaf", "\nValue: ", "\nGain: "), Quality)]
|
||||
|
||||
Reference in New Issue
Block a user