[R-package] GPL2 dependency reduction and some fixes (#1401)
* [R] do not remove zero coefficients from gblinear dump * [R] switch from stringr to stringi * fix #1399 * [R] separate ggplot backend, add base r graphics, cleanup, more plots, tests * add missing include in amalgamation - fixes building R package in linux * add forgotten file * [R] fix DESCRIPTION * [R] fix travis check issue and some cleanup
This commit is contained in:
committed by
Tong He
parent
f6423056c0
commit
d5c143367d
135
R-package/R/xgb.ggplot.R
Normal file
135
R-package/R/xgb.ggplot.R
Normal file
@@ -0,0 +1,135 @@
|
||||
# ggplot backend for the xgboost plotting facilities
|
||||
|
||||
|
||||
#' @rdname xgb.plot.importance
#' @export
xgb.ggplot.importance <- function(importance_matrix = NULL, top_n = NULL, measure = NULL,
                                  rel_to_first = FALSE, n_clusters = c(1:10), ...) {

  # Let the base plotting function prepare the table; plot = FALSE asks it
  # not to draw anything itself.
  importance_matrix <- xgb.plot.importance(importance_matrix, top_n = top_n, measure = measure,
                                           rel_to_first = rel_to_first, plot = FALSE, ...)

  # Optional dependencies of this backend, checked in this order, each with
  # its own error message.
  missing_pkg_msg <- c(
    "ggplot2" = "ggplot2 package is required",
    "Ckmeans.1d.dp" = "Ckmeans.1d.dp package is required"
  )
  for (pkg in names(missing_pkg_msg)) {
    if (!requireNamespace(pkg, quietly = TRUE)) {
      stop(missing_pkg_msg[[pkg]], call. = FALSE)
    }
  }

  # One-dimensional clustering of the importance values; the cluster label
  # is stored as a character column and used as the bar fill below.
  clustering <- suppressWarnings(
    Ckmeans.1d.dp::Ckmeans.1d.dp(importance_matrix$Importance, n_clusters)
  )
  importance_matrix[, Cluster := as.character(clustering$cluster)]

  # Features go on the x axis with their factor levels reversed; the axes are
  # flipped afterwards so the bars are drawn horizontally.
  bar_mapping <- ggplot2::aes(x = factor(Feature, levels = rev(Feature)), y = Importance, width = 0.05)
  canvas <- ggplot2::ggplot(importance_matrix, bar_mapping, environment = environment())
  bars <- ggplot2::geom_bar(ggplot2::aes(fill = Cluster), stat = "identity", position = "identity")
  look <- ggplot2::theme(plot.title = ggplot2::element_text(lineheight = .9, face = "bold"),
                         panel.grid.major.y = ggplot2::element_blank())

  canvas +
    bars +
    ggplot2::coord_flip() +
    ggplot2::xlab("Features") +
    ggplot2::ggtitle("Feature importance") +
    look
}
|
||||
|
||||
|
||||
#' @rdname xgb.plot.deepness
#' @export
xgb.ggplot.deepness <- function(model = NULL, which = c("2x1", "max.depth", "med.depth", "med.weight")) {

  # ggplot2 is an optional dependency, so it is checked at call time.
  if (!requireNamespace("ggplot2", quietly = TRUE)) {
    stop("ggplot2 package is required for plotting the graph deepness.", call. = FALSE)
  }

  which <- match.arg(which)

  # Table produced by the base plotting function without drawing; the columns
  # Depth, Cover, Tree and Weight are the ones used below.
  dt_depths <- xgb.plot.deepness(model = model, plot = FALSE)
  # Row count (N) and mean Cover for every Depth value.
  dt_summaries <- dt_depths[, .(.N, Cover = mean(Cover)), Depth]
  setkey(dt_summaries, 'Depth')

  # Default mode: two stacked bar charts drawn immediately on one page.
  if (which == "2x1") {
    upper_theme <- ggplot2::theme(
      plot.title = ggplot2::element_text(lineheight = 0.9, face = "bold"),
      panel.grid.major.y = ggplot2::element_blank(),
      axis.ticks = ggplot2::element_blank(),
      axis.text.x = ggplot2::element_blank()
    )

    count_by_depth <- ggplot2::ggplot(dt_summaries) +
      ggplot2::geom_bar(ggplot2::aes(x = Depth, y = N), stat = "Identity") +
      ggplot2::xlab("") +
      ggplot2::ylab("Number of leafs") +
      ggplot2::ggtitle("Model complexity") +
      upper_theme

    cover_by_depth <- ggplot2::ggplot(dt_summaries) +
      ggplot2::geom_bar(ggplot2::aes(x = Depth, y = Cover), stat = "Identity") +
      ggplot2::xlab("Leaf depth") +
      ggplot2::ylab("Weighted cover")

    multiplot(count_by_depth, cover_by_depth, cols = 1)
    return(invisible(list(count_by_depth, cover_by_depth)))
  }

  # Remaining modes: one value per tree. The unnamed aggregate ends up in
  # column V1, which is what the y aesthetic refers to.
  per_tree <- switch(which,
    "max.depth" = dt_depths[, max(Depth), Tree],
    "med.depth" = dt_depths[, median(as.numeric(Depth)), Tree],
    "med.weight" = dt_depths[, median(abs(Weight)), Tree]
  )

  y_label <- switch(which,
    "max.depth" = "Max tree leaf depth",
    "med.depth" = "Median tree leaf depth",
    "med.weight" = "Median absolute leaf weight"
  )

  # The two depth modes jitter the points vertically; the weight mode plots
  # them as they are.
  tree_mapping <- ggplot2::aes(x = Tree, y = V1)
  point_layer <- if (which == "med.weight") {
    ggplot2::geom_point(tree_mapping,
                        alpha=0.4, size=3, stroke=0)
  } else {
    ggplot2::geom_jitter(tree_mapping,
                         height = 0.15, alpha=0.4, size=3, stroke=0)
  }

  ggplot2::ggplot(per_tree) +
    point_layer +
    ggplot2::xlab("tree #") +
    ggplot2::ylab(y_label)
}
|
||||
|
||||
# Draw several plot objects on one page, arranged in a grid of rows and columns.
#
# ...   the plots to draw; each is drawn by calling print() on it
# cols  number of columns in the grid; rows are added as needed
#
# Internal utility function. Returns NULL invisibly when called without plots.
multiplot <- function(..., cols = 1) {
  plots <- list(...)
  num_plots <- length(plots)

  # Nothing to draw. Without this guard a new grid page would be opened and
  # the loop below would try to print plots[[1]] from an empty list.
  if (num_plots == 0) {
    return(invisible(NULL))
  }

  # Cell numbers laid out column by column; cell i hosts the i-th plot.
  n_rows <- ceiling(num_plots / cols)
  layout <- matrix(seq(1, cols * n_rows),
                   ncol = cols, nrow = n_rows)

  if (num_plots == 1) {
    # A single plot needs no layout: print it on the current device.
    print(plots[[1]])
  } else {
    grid::grid.newpage()
    grid::pushViewport(grid::viewport(layout = grid::grid.layout(nrow(layout), ncol(layout))))
    for (i in seq_len(num_plots)) {
      # Row/column position of the layout cell that holds this subplot
      cell <- which(layout == i, arr.ind = TRUE)

      print(
        plots[[i]], vp = grid::viewport(
          layout.pos.row = cell[, "row"],
          layout.pos.col = cell[, "col"]
        )
      )
    }
  }
}
|
||||
|
||||
# Register names with utils::globalVariables() so that R CMD check does not
# report them as undefined global variables.
globalVariables(c(
  "Cluster",
  "ggplot",
  "aes",
  "geom_bar",
  "coord_flip",
  "xlab",
  "ylab",
  "ggtitle",
  "theme",
  "element_blank",
  "element_text"
))
|
||||
Reference in New Issue
Block a user