@@ -2,48 +2,48 @@
|
||||
#'
|
||||
#' Visualizes distributions related to depth of tree leafs.
|
||||
#' \code{xgb.plot.deepness} uses base R graphics, while \code{xgb.ggplot.deepness} uses the ggplot backend.
|
||||
#'
|
||||
#'
|
||||
#' @param model either an \code{xgb.Booster} model generated by the \code{xgb.train} function
|
||||
#' or a data.table result of the \code{xgb.model.dt.tree} function.
|
||||
#' @param plot (base R barplot) whether a barplot should be produced.
|
||||
#' @param plot (base R barplot) whether a barplot should be produced.
|
||||
#' If FALSE, only a data.table is returned.
|
||||
#' @param which which distribution to plot (see details).
|
||||
#' @param ... other parameters passed to \code{barplot} or \code{plot}.
|
||||
#'
|
||||
#'
|
||||
#' @details
|
||||
#'
|
||||
#'
|
||||
#' When \code{which="2x1"}, two distributions with respect to the leaf depth
|
||||
#' are plotted on top of each other:
|
||||
#' \itemize{
|
||||
#' \item the distribution of the number of leafs in a tree model at a certain depth;
|
||||
#' \item the distribution of average weighted number of observations ("cover")
|
||||
#' \item the distribution of average weighted number of observations ("cover")
|
||||
#' ending up in leafs at certain depth.
|
||||
#' }
|
||||
#' Those could be helpful in determining sensible ranges of the \code{max_depth}
|
||||
#' Those could be helpful in determining sensible ranges of the \code{max_depth}
|
||||
#' and \code{min_child_weight} parameters.
|
||||
#'
|
||||
#'
|
||||
#' When \code{which="max.depth"} or \code{which="med.depth"}, plots of either maximum or median depth
|
||||
#' per tree with respect to tree number are created. And \code{which="med.weight"} allows to see how
|
||||
#' a tree's median absolute leaf weight changes through the iterations.
|
||||
#'
|
||||
#' This function was inspired by the blog post
|
||||
#' \url{https://github.com/aysent/random-forest-leaf-visualization}.
|
||||
#'
|
||||
#'
|
||||
#' @return
|
||||
#'
|
||||
#'
|
||||
#' Other than producing plots (when \code{plot=TRUE}), the \code{xgb.plot.deepness} function
|
||||
#' silently returns a processed data.table where each row corresponds to a terminal leaf in a tree model,
|
||||
#' and contains information about leaf's depth, cover, and weight (which is used in calculating predictions).
|
||||
#'
|
||||
#'
|
||||
#' The \code{xgb.ggplot.deepness} silently returns either a list of two ggplot graphs when \code{which="2x1"}
|
||||
#' or a single ggplot graph for the other \code{which} options.
|
||||
#'
|
||||
#' @seealso
|
||||
#'
|
||||
#' @seealso
|
||||
#'
|
||||
#' \code{\link{xgb.train}}, \code{\link{xgb.model.dt.tree}}.
|
||||
#'
|
||||
#'
|
||||
#' @examples
|
||||
#'
|
||||
#'
|
||||
#' data(agaricus.train, package='xgboost')
|
||||
#'
|
||||
#' # Change max_depth to a higher number to get a more significant result
|
||||
@@ -53,16 +53,16 @@
|
||||
#'
|
||||
#' xgb.plot.deepness(bst)
|
||||
#' xgb.ggplot.deepness(bst)
|
||||
#'
|
||||
#'
|
||||
#' xgb.plot.deepness(bst, which='max.depth', pch=16, col=rgb(0,0,1,0.3), cex=2)
|
||||
#'
|
||||
#'
|
||||
#' xgb.plot.deepness(bst, which='med.weight', pch=16, col=rgb(0,0,1,0.3), cex=2)
|
||||
#'
|
||||
#' @rdname xgb.plot.deepness
|
||||
#' @export
|
||||
xgb.plot.deepness <- function(model = NULL, which = c("2x1", "max.depth", "med.depth", "med.weight"),
|
||||
plot = TRUE, ...) {
|
||||
|
||||
|
||||
if (!(inherits(model, "xgb.Booster") || is.data.table(model)))
|
||||
stop("model: Has to be either an xgb.Booster model generaged by the xgb.train function\n",
|
||||
"or a data.table result of the xgb.importance function")
|
||||
@@ -71,32 +71,32 @@ xgb.plot.deepness <- function(model = NULL, which = c("2x1", "max.depth", "med.d
|
||||
stop("igraph package is required for plotting the graph deepness.", call. = FALSE)
|
||||
|
||||
which <- match.arg(which)
|
||||
|
||||
|
||||
dt_tree <- model
|
||||
if (inherits(model, "xgb.Booster"))
|
||||
dt_tree <- xgb.model.dt.tree(model = model)
|
||||
|
||||
|
||||
if (!all(c("Feature", "Tree", "ID", "Yes", "No", "Cover") %in% colnames(dt_tree)))
|
||||
stop("Model tree columns are not as expected!\n",
|
||||
" Note that this function works only for tree models.")
|
||||
|
||||
|
||||
dt_depths <- merge(get.leaf.depth(dt_tree), dt_tree[, .(ID, Cover, Weight = Quality)], by = "ID")
|
||||
setkeyv(dt_depths, c("Tree", "ID"))
|
||||
# count by depth levels, and also calculate average cover at a depth
|
||||
dt_summaries <- dt_depths[, .(.N, Cover = mean(Cover)), Depth]
|
||||
setkey(dt_summaries, "Depth")
|
||||
|
||||
|
||||
if (plot) {
|
||||
if (which == "2x1") {
|
||||
op <- par(no.readonly = TRUE)
|
||||
par(mfrow = c(2,1),
|
||||
oma = c(3,1,3,1) + 0.1,
|
||||
mar = c(1,4,1,0) + 0.1)
|
||||
par(mfrow = c(2, 1),
|
||||
oma = c(3, 1, 3, 1) + 0.1,
|
||||
mar = c(1, 4, 1, 0) + 0.1)
|
||||
|
||||
dt_summaries[, barplot(N, border = NA, ylab = 'Number of leafs', ...)]
|
||||
|
||||
dt_summaries[, barplot(Cover, border = NA, ylab = "Weighted cover", names.arg = Depth, ...)]
|
||||
|
||||
|
||||
title("Model complexity", xlab = "Leaf depth", outer = TRUE, line = 1)
|
||||
par(op)
|
||||
} else if (which == "max.depth") {
|
||||
@@ -123,14 +123,14 @@ get.leaf.depth <- function(dt_tree) {
|
||||
dt_tree[Feature != "Leaf", .(ID, To = No, Tree)]
|
||||
))
|
||||
# whether "To" is a leaf:
|
||||
dt_edges <-
|
||||
dt_edges <-
|
||||
merge(dt_edges,
|
||||
dt_tree[Feature == "Leaf", .(ID, Leaf = TRUE)],
|
||||
all.x = TRUE, by.x = "To", by.y = "ID")
|
||||
dt_edges[is.na(Leaf), Leaf := FALSE]
|
||||
|
||||
dt_edges[, {
|
||||
graph <- igraph::graph_from_data_frame(.SD[,.(ID, To)])
|
||||
graph <- igraph::graph_from_data_frame(.SD[, .(ID, To)])
|
||||
# min(ID) in a tree is a root node
|
||||
paths_tmp <- igraph::shortest_paths(graph, from = min(ID), to = To[Leaf == TRUE])
|
||||
# list of paths to each leaf in a tree
|
||||
|
||||
Reference in New Issue
Block a user