215 lines
7.6 KiB
R
215 lines
7.6 KiB
R
# ggplot backend for the xgboost plotting facilities
|
|
|
|
|
|
#' @rdname xgb.plot.importance
#' @export
xgb.ggplot.importance <- function(importance_matrix = NULL, top_n = NULL, measure = NULL,
                                  rel_to_first = FALSE, n_clusters = seq_len(10), ...) {
  # Check the optional plotting dependencies first so that a missing package is
  # reported before any table preparation work is done.
  if (!requireNamespace("ggplot2", quietly = TRUE)) {
    stop("ggplot2 package is required", call. = FALSE)
  }
  if (!requireNamespace("Ckmeans.1d.dp", quietly = TRUE)) {
    stop("Ckmeans.1d.dp package is required", call. = FALSE)
  }

  # Pass top_n / measure / rel_to_first through to the base-graphics function
  # and ask it only for the prepared table (plot = FALSE).
  importance_matrix <- xgb.plot.importance(importance_matrix, top_n = top_n, measure = measure,
                                           rel_to_first = rel_to_first, plot = FALSE, ...)

  # Cluster the importance values in 1-D. n_clusters is passed as the number
  # (or range) of clusters, and the resulting cluster id becomes the bar fill.
  # Warnings from Ckmeans.1d.dp are deliberately suppressed.
  clusters <- suppressWarnings(
    Ckmeans.1d.dp::Ckmeans.1d.dp(importance_matrix$Importance, n_clusters)
  )
  importance_matrix[, Cluster := as.character(clusters$cluster)]

  # Reversed factor levels put the first row of the table at the top after
  # coord_flip().
  # The bar width is a constant, so it is set as a layer parameter rather than
  # mapped as an aesthetic.
  # The deprecated `environment` argument of ggplot() is not needed: aes()
  # captures its calling environment itself.
  plot <-
    ggplot2::ggplot(importance_matrix,
                    ggplot2::aes(x = factor(Feature, levels = rev(Feature)), y = Importance)) +
    ggplot2::geom_bar(ggplot2::aes(fill = Cluster), stat = "identity",
                      position = "identity", width = 0.5) +
    ggplot2::coord_flip() +
    ggplot2::xlab("Features") +
    ggplot2::ggtitle("Feature importance") +
    ggplot2::theme(plot.title = ggplot2::element_text(lineheight = .9, face = "bold"),
                   panel.grid.major.y = ggplot2::element_blank())

  plot
}
|
|
|
|
|
|
#' @rdname xgb.plot.deepness
#' @export
xgb.ggplot.deepness <- function(model = NULL, which = c("2x1", "max.depth", "med.depth", "med.weight")) {
  if (!requireNamespace("ggplot2", quietly = TRUE)) {
    stop("ggplot2 package is required for plotting the graph deepness.", call. = FALSE)
  }
  which <- match.arg(which)

  # Per-leaf table produced by the base-graphics backend (no drawing).
  dt_depths <- xgb.plot.deepness(model = model, plot = FALSE)
  # Leaf count (.N) and mean cover for every depth level, keyed by Depth.
  dt_summaries <- dt_depths[, .(.N, Cover = mean(Cover)), Depth]
  setkey(dt_summaries, "Depth")

  if (which == "2x1") {
    # Top panel: number of leaves per depth. Its x-axis title, ticks and labels
    # are blanked; the depth axis is labelled on the bottom panel.
    top_theme <- ggplot2::theme(
      plot.title = ggplot2::element_text(lineheight = 0.9, face = "bold"),
      panel.grid.major.y = ggplot2::element_blank(),
      axis.ticks = ggplot2::element_blank(),
      axis.text.x = ggplot2::element_blank()
    )
    leaves_plot <- ggplot2::ggplot(dt_summaries) +
      ggplot2::geom_bar(ggplot2::aes(x = Depth, y = N), stat = "Identity") +
      ggplot2::xlab("") +
      ggplot2::ylab("Number of leafs") +
      ggplot2::ggtitle("Model complexity") +
      top_theme

    # Bottom panel: mean cover per depth.
    cover_plot <- ggplot2::ggplot(dt_summaries) +
      ggplot2::geom_bar(ggplot2::aes(x = Depth, y = Cover), stat = "Identity") +
      ggplot2::xlab("Leaf depth") +
      ggplot2::ylab("Weighted cover")

    # Draw both panels stacked in one column, and hand the plot objects back.
    multiplot(leaves_plot, cover_plot, cols = 1)
    return(invisible(list(leaves_plot, cover_plot)))
  }

  # The remaining choices all plot one per-tree statistic (column V1) against
  # the tree index.
  per_tree <- switch(which,
    max.depth  = dt_depths[, max(Depth), Tree],
    med.depth  = dt_depths[, median(as.numeric(Depth)), Tree],
    med.weight = dt_depths[, median(abs(Weight)), Tree]
  )
  y_title <- switch(which,
    max.depth  = "Max tree leaf depth",
    med.depth  = "Median tree leaf depth",
    med.weight = "Median absolute leaf weight"
  )
  # Depth statistics are jittered vertically; leaf weights are drawn as-is.
  point_layer <- if (which == "med.weight") {
    ggplot2::geom_point(ggplot2::aes(x = Tree, y = V1),
                        alpha = 0.4, size = 3, stroke = 0)
  } else {
    ggplot2::geom_jitter(ggplot2::aes(x = Tree, y = V1),
                         height = 0.15, alpha = 0.4, size = 3, stroke = 0)
  }

  ggplot2::ggplot(per_tree) +
    point_layer +
    ggplot2::xlab("tree #") +
    ggplot2::ylab(y_title)
}
|
|
|
|
#' @rdname xgb.plot.shap.summary
#' @export
xgb.ggplot.shap.summary <- function(data, shap_contrib = NULL, features = NULL, top_n = 10, model = NULL,
                                    trees = NULL, target_class = NULL, approxcontrib = FALSE, subsample = NULL) {
  data_list <- xgb.shap.data(
    data = data,
    shap_contrib = shap_contrib,
    features = features,
    top_n = top_n,
    model = model,
    trees = trees,
    target_class = target_class,
    approxcontrib = approxcontrib,
    subsample = subsample,
    max_observations = 10000 # 10,000 samples per feature.
  )
  # Long table with one row per (observation, feature): the feature value,
  # standardized here, and its SHAP contribution.
  p_data <- prepare.ggplot.shap.data(data_list, normalize = TRUE)
  # Reverse factor levels so that the first level is at the top of the plot
  p_data[, "feature" := factor(feature, rev(levels(feature)))]

  # Columns are mapped by bare name. Using `p_data$...` inside aes() bypasses
  # ggplot2's data mask and is documented as an anti-pattern.
  p <- ggplot2::ggplot(p_data, ggplot2::aes(x = feature, y = shap_value, colour = feature_value)) +
    # Jitter only along the discrete feature axis; height = 0 keeps each point
    # at its exact SHAP value.
    ggplot2::geom_jitter(alpha = 0.5, width = 0.1, height = 0) +
    # Standardized feature values are coloured on a fixed [-3, 3] range.
    ggplot2::scale_colour_viridis_c(limits = c(-3, 3), option = "plasma", direction = -1) +
    # Reference line at a SHAP value of 0.
    ggplot2::geom_hline(yintercept = 0, colour = "darkgrey") +
    ggplot2::coord_flip()

  p
}
|
|
|
|
#' Combine feature values and SHAP values
#'
#' Internal function used to combine and melt feature values and SHAP contributions
#' as required for ggplot functions related to SHAP.
#'
#' Both inputs are converted to dense tables and melted to long format.
#' Observations are paired by row position (`id`) and features by column name.
#' The result is an inner join, so a feature present in only one of the two
#' inputs is dropped.
#'
#' @param data_list The result of `xgb.shap.data()`; its `data` and
#'   `shap_contrib` elements are used.
#' @param normalize Whether to standardize feature values to mean 0 and
#'   standard deviation 1. This is useful for comparing multiple features on the same
#'   plot. Default is \code{FALSE}.
#'
#' @return A `data.table` containing the observation ID, the feature name, the
#'   feature value (normalized if specified), and the SHAP contribution value.
#' @noRd
#' @keywords internal
prepare.ggplot.shap.data <- function(data_list, normalize = FALSE) {
  data <- data_list[["data"]]
  shap_contrib <- data_list[["shap_contrib"]]

  data <- data.table::as.data.table(as.matrix(data))
  if (normalize) {
    # The logical argument `normalize` shadows the file-level normalize()
    # function. Look up the function binding explicitly, instead of relying on
    # R's function lookup to skip the non-function binding.
    normalize_fun <- get("normalize", mode = "function")
    data[, (names(data)) := lapply(.SD, normalize_fun)]
  }
  data[, "id" := seq_len(nrow(data))]
  data_m <- data.table::melt.data.table(data, id.vars = "id", variable.name = "feature", value.name = "feature_value")

  shap_contrib <- data.table::as.data.table(as.matrix(shap_contrib))
  shap_contrib[, "id" := seq_len(nrow(shap_contrib))]
  shap_contrib_m <- data.table::melt.data.table(shap_contrib, id.vars = "id", variable.name = "feature", value.name = "shap_value")

  # Pair each observation's feature value with its SHAP contribution.
  p_data <- data.table::merge.data.table(data_m, shap_contrib_m, by = c("id", "feature"))

  p_data
}
|
|
|
|
#' Scale feature values
#'
#' Internal function that scales feature values to mean 0 and standard deviation 1.
#' Useful to compare multiple features on the same plot.
#'
#' @param x Numeric vector. Missing values are ignored when computing the
#'   mean and standard deviation, and are kept as missing in the result.
#'
#' @return Numeric vector with mean 0 and standard deviation 1. When the
#'   standard deviation is 0 (constant input) or undefined (a single non-missing
#'   value), `x` is only centered. This gives zeros instead of `NaN`/`NA`.
#' @noRd
#' @keywords internal
normalize <- function(x) {
  loc <- mean(x, na.rm = TRUE)
  scale <- stats::sd(x, na.rm = TRUE)
  centered <- x - loc

  # Dividing by a zero or NA standard deviation would turn every value into
  # NaN/NA; a constant feature is better represented as all zeros.
  if (is.na(scale) || scale == 0) {
    return(centered)
  }

  centered / scale
}
|
|
|
|
# Plot multiple ggplot graphs aligned by rows and columns.
#
# Internal utility function.
# Plots are placed column by column into a grid with `cols` columns and as many
# rows as needed. A single plot is printed directly, and with no plots nothing
# is drawn.
#
# ...  the plots
# cols number of columns (only used when there are two or more plots)
#
# Called for its drawing side effect; returns NULL invisibly.
multiplot <- function(..., cols) {
  plots <- list(...)
  num_plots <- length(plots)

  if (num_plots == 0) {
    return(invisible(NULL))
  }
  if (num_plots == 1) {
    print(plots[[1]])
    return(invisible(NULL))
  }

  num_rows <- ceiling(num_plots / cols)
  # Cell values are plot indices; matrix() fills column by column.
  layout <- matrix(seq_len(cols * num_rows), ncol = cols, nrow = num_rows)

  grid::grid.newpage()
  grid::pushViewport(grid::viewport(layout = grid::grid.layout(nrow(layout), ncol(layout))))
  for (i in seq_len(num_plots)) {
    # Row/column position of the layout cell that holds plot i.
    cell <- which(layout == i, arr.ind = TRUE)
    print(
      plots[[i]],
      vp = grid::viewport(
        layout.pos.row = cell[1, "row"],
        layout.pos.col = cell[1, "col"]
      )
    )
  }

  invisible(NULL)
}
|
|
|
|
# Declare names that are used unquoted in this file. These are data.table `[`
# and ggplot2 aes() column references, plus several ggplot2 function names.
# Declaring them silences R CMD check's "no visible binding for global
# variable" notes.
globalVariables(c(
  "Cluster", "ggplot", "aes", "geom_bar", "coord_flip", "xlab", "ylab", "ggtitle", "theme",
  "element_blank", "element_text", "V1", "Weight", "feature", "shap_value", "feature_value"
))
|