136 lines
4.7 KiB
R
136 lines
4.7 KiB
R
# ggplot backend for the xgboost plotting facilities
|
|
|
|
|
|
#' @rdname xgb.plot.importance
#' @export
xgb.ggplot.importance <- function(importance_matrix = NULL, top_n = NULL, measure = NULL,
                                  rel_to_first = FALSE, n_clusters = c(1:10), ...) {
  # ggplot2 rendering of the feature-importance bar chart.
  #
  # The prepared importance table comes from xgb.plot.importance(..., plot = FALSE)
  # (defined elsewhere). Features are grouped with Ckmeans.1d.dp::Ckmeans.1d.dp(),
  # which is given `n_clusters` as its cluster-count argument, and bars are
  # filled by cluster. The ggplot object is returned without being printed.

  # Check the optional dependencies first, so a missing package is reported
  # before any work is done on the importance matrix.
  if (!requireNamespace("ggplot2", quietly = TRUE)) {
    stop("ggplot2 package is required", call. = FALSE)
  }
  if (!requireNamespace("Ckmeans.1d.dp", quietly = TRUE)) {
    stop("Ckmeans.1d.dp package is required", call. = FALSE)
  }

  # plot = FALSE: only the prepared table (with Feature and Importance
  # columns, as used below) is wanted here.
  importance_matrix <- xgb.plot.importance(importance_matrix, top_n = top_n, measure = measure,
                                           rel_to_first = rel_to_first, plot = FALSE, ...)

  # Warnings from the clustering step are deliberately silenced.
  clusters <- suppressWarnings(
    Ckmeans.1d.dp::Ckmeans.1d.dp(importance_matrix$Importance, n_clusters)
  )
  # data.table `:=` adds the cluster label column by reference; it is a
  # character so ggplot2 treats it as a discrete fill.
  importance_matrix[, Cluster := as.character(clusters$cluster)]

  # Feature levels are reversed so that, after coord_flip(), the first row of
  # the table is drawn at the top of the chart.
  ggplot2::ggplot(importance_matrix,
                  ggplot2::aes(x = factor(Feature, levels = rev(Feature)), y = Importance),
                  environment = environment()) +
    # A constant bar width is a layer parameter, not a mapped aesthetic;
    # the previous 0.05 rendered the bars as hair-thin lines.
    ggplot2::geom_bar(ggplot2::aes(fill = Cluster), stat = "identity", position = "identity",
                      width = 0.5) +
    ggplot2::coord_flip() +
    ggplot2::xlab("Features") +
    ggplot2::ggtitle("Feature importance") +
    ggplot2::theme(plot.title = ggplot2::element_text(lineheight = .9, face = "bold"),
                   panel.grid.major.y = ggplot2::element_blank())
}
|
|
|
|
|
|
#' @rdname xgb.plot.deepness
#' @export
xgb.ggplot.deepness <- function(model = NULL, which = c("2x1", "max.depth", "med.depth", "med.weight")) {
  # ggplot2 rendering of the tree-depth diagnostics. `which` selects the chart:
  #   "2x1"        - leaf count and mean Cover per leaf depth, drawn one above
  #                  the other with multiplot(); both plots are returned
  #                  invisibly as a list.
  #   "max.depth"  - maximum leaf depth of each tree (jittered points).
  #   "med.depth"  - median leaf depth of each tree (jittered points).
  #   "med.weight" - median absolute leaf weight of each tree (points).
  # The single-chart variants return the ggplot object without printing it.

  if (!requireNamespace("ggplot2", quietly = TRUE)) {
    stop("ggplot2 package is required for plotting the graph deepness.", call. = FALSE)
  }

  which <- match.arg(which)

  # Leaf-level table built by xgb.plot.deepness() (defined elsewhere); the
  # code below reads its Tree, Depth, Cover and Weight columns.
  leaves <- xgb.plot.deepness(model = model, plot = FALSE)

  # Leaf count (.N) and mean Cover for every Depth, keyed (hence sorted) by Depth.
  by_depth <- leaves[, .(.N, Cover = mean(Cover)), Depth]
  setkey(by_depth, "Depth")

  if (which == "2x1") {
    # Upper chart: its x-axis text, ticks and label are blanked; the
    # "Leaf depth" label appears only on the lower chart.
    count_plot <- ggplot2::ggplot(by_depth) +
      ggplot2::geom_bar(ggplot2::aes(x = Depth, y = N), stat = "Identity") +
      ggplot2::xlab("") +
      ggplot2::ylab("Number of leafs") +
      ggplot2::ggtitle("Model complexity") +
      ggplot2::theme(
        plot.title = ggplot2::element_text(lineheight = 0.9, face = "bold"),
        panel.grid.major.y = ggplot2::element_blank(),
        axis.ticks = ggplot2::element_blank(),
        axis.text.x = ggplot2::element_blank()
      )

    cover_plot <- ggplot2::ggplot(by_depth) +
      ggplot2::geom_bar(ggplot2::aes(x = Depth, y = Cover), stat = "Identity") +
      ggplot2::xlab("Leaf depth") +
      ggplot2::ylab("Weighted cover")

    multiplot(count_plot, cover_plot, cols = 1)
    return(invisible(list(count_plot, cover_plot)))
  }

  # Every other variant plots one aggregated value per tree; data.table names
  # the unnamed aggregate column V1.
  per_tree <- switch(which,
    max.depth  = leaves[, max(Depth), Tree],
    med.depth  = leaves[, median(as.numeric(Depth)), Tree],
    med.weight = leaves[, median(abs(Weight)), Tree]
  )
  y_title <- switch(which,
    max.depth  = "Max tree leaf depth",
    med.depth  = "Median tree leaf depth",
    med.weight = "Median absolute leaf weight"
  )

  # The depth-based charts add a small vertical jitter (height = 0.15),
  # presumably to separate trees that share a depth value; weights are
  # plotted without jitter.
  if (which == "med.weight") {
    point_layer <- ggplot2::geom_point(ggplot2::aes(x = Tree, y = V1),
                                       alpha = 0.4, size = 3, stroke = 0)
  } else {
    point_layer <- ggplot2::geom_jitter(ggplot2::aes(x = Tree, y = V1),
                                        height = 0.15, alpha = 0.4, size = 3, stroke = 0)
  }

  ggplot2::ggplot(per_tree) +
    point_layer +
    ggplot2::xlab("tree #") +
    ggplot2::ylab(y_title)
}
|
|
|
|
# Draw several ggplot graphs on one page, arranged in a grid of rows and columns.
#
# ...   the plots to draw.
# cols  number of grid columns. The grid is filled column by column (matrix()
#       fills by column by default), so with cols = 2 and four plots the first
#       column holds plots 1 and 2.
#
# A single plot is simply print()ed, and the value of print() is returned.
# With no plots nothing is drawn and NULL is returned invisibly.
# Internal utility function.
multiplot <- function(..., cols = 1) {
  plots <- list(...)
  num_plots <- length(plots)

  # Nothing to draw: avoid opening a new page and indexing plots[[1]].
  if (num_plots == 0) {
    return(invisible(NULL))
  }

  if (num_plots == 1) {
    print(plots[[1]])
  } else {
    num_rows <- ceiling(num_plots / cols)
    # Cell (r, c) of `layout` holds the index of the plot drawn there.
    layout <- matrix(seq_len(cols * num_rows), ncol = cols, nrow = num_rows)

    grid::grid.newpage()
    grid::pushViewport(grid::viewport(layout = grid::grid.layout(nrow(layout), ncol(layout))))

    for (i in seq_len(num_plots)) {
      # which(arr.ind = TRUE) yields a matrix with "row" and "col" columns
      # giving the grid cell that holds plot i.
      pos <- which(layout == i, arr.ind = TRUE)
      print(
        plots[[i]],
        vp = grid::viewport(layout.pos.row = pos[, "row"], layout.pos.col = pos[, "col"])
      )
    }
  }
}
|
|
|
|
# Names referenced through non-standard evaluation in this file (data.table
# column expressions and ggplot2 aesthetics), declared so that R CMD check does
# not report "no visible binding for global variable". The bare ggplot2
# function names from the original list are kept; declaring a name that is
# also declared elsewhere is harmless because the registered list is
# de-duplicated.
globalVariables(c(
  "Cluster", "ggplot", "aes", "geom_bar", "coord_flip", "xlab", "ylab", "ggtitle", "theme",
  "element_blank", "element_text", "V1", "Weight",
  # Column names used unquoted in xgb.ggplot.importance / xgb.ggplot.deepness.
  "Feature", "Importance", "Depth", "Tree", "Cover", "N",
  # data.table special symbols used in `.(.N, ...)`.
  ".N", "."
))
|