Add SHAP summary plot using ggplot2 (#5882)
* add SHAP summary plot using ggplot2
* Update xgb.plot.shap
* Update example in xgb.plot.shap documentation
* update logic, add tests
* whitespace fixes
* whitespace fixes for test_helpers
* namespace for sd function
* explicitly declare variables that are automatically evaluated by data.table
* Fix R lint

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -99,6 +99,85 @@ xgb.ggplot.deepness <- function(model = NULL, which = c("2x1", "max.depth", "med
|
||||
}
|
||||
}
|
||||
|
||||
#' @rdname xgb.plot.shap.summary
#' @export
xgb.ggplot.shap.summary <- function(data, shap_contrib = NULL, features = NULL, top_n = 10, model = NULL,
                                    trees = NULL, target_class = NULL, approxcontrib = FALSE, subsample = NULL) {
  # ggplot2 version of the SHAP summary plot: one jittered strip of points per
  # feature, positioned by SHAP value and coloured by the standardized feature
  # value. Returns the ggplot object.

  # Collect the feature values and SHAP contributions to plot; every argument
  # is forwarded unchanged to xgb.shap.data().
  shap_inputs <- xgb.shap.data(
    data = data,
    shap_contrib = shap_contrib,
    features = features,
    top_n = top_n,
    model = model,
    trees = trees,
    target_class = target_class,
    approxcontrib = approxcontrib,
    subsample = subsample,
    max_observations = 10000  # 10,000 samples per feature.
  )

  # Long format (id, feature, feature_value, shap_value). Feature values are
  # standardized so that all features can share a single colour scale.
  long_dt <- prepare.ggplot.shap.data(shap_inputs, normalize = TRUE)

  # Reverse factor levels so that the first level is at the top of the plot
  # (the axes are flipped below).
  long_dt[, "feature" := factor(feature, levels = rev(levels(feature)))]

  point_aes <- ggplot2::aes(x = feature, y = shap_value, colour = feature_value)
  points <- ggplot2::geom_jitter(alpha = 0.5, width = 0.1)
  # Colour limits are fixed at +/- 3 (standardized units).
  # NOTE(review): values outside the limits presumably fall back to the
  # scale's NA colour -- confirm against ggplot2's out-of-bounds handling.
  colour_scale <- ggplot2::scale_colour_viridis_c(limits = c(-3, 3), option = "plasma", direction = -1)
  # Reference line at a SHAP value of zero.
  zero_line <- ggplot2::geom_abline(slope = 0, intercept = 0, colour = "darkgrey")

  ggplot2::ggplot(long_dt, point_aes) +
    points +
    colour_scale +
    zero_line +
    ggplot2::coord_flip()
}
|
||||
|
||||
#' Reshape feature values and SHAP contributions into one long table
#'
#' Melts the feature matrix and the SHAP contribution matrix to long format
#' and joins them on observation ID and feature name, which is the layout the
#' ggplot functions expect.
#'
#' Internal utility function.
#'
#' @param data_list List containing 'data' and 'shap_contrib' returned by
#'   \code{xgb.shap.data()}.
#' @param normalize Whether to standardize each feature's values to have mean 0
#'   and standard deviation 1 (useful for comparing multiple features on the
#'   same plot). Default \code{FALSE}.
#'
#' @return A data.table with one row per observation and feature, holding the
#'   observation ID (\code{id}), the feature name (\code{feature}), the feature
#'   value (\code{feature_value}, normalized if requested) and the SHAP
#'   contribution (\code{shap_value}).
prepare.ggplot.shap.data <- function(data_list, normalize = FALSE) {
  # Adds a row-number `id` column (by reference) and melts every other column
  # into (feature, <value_name>) pairs.
  to_long <- function(wide, value_name) {
    wide[, "id" := seq_len(nrow(wide))]
    data.table::melt.data.table(wide, id.vars = "id", variable.name = "feature", value.name = value_name)
  }

  feature_dt <- data.table::as.data.table(as.matrix(data_list[["data"]]))
  if (normalize) {
    # `normalize` is the logical flag in this scope; lapply() still reaches the
    # normalize() function because match.fun() looks the name up in function
    # mode. Keep this call exactly as written.
    feature_dt[, (names(feature_dt)) := lapply(.SD, normalize)]
  }
  features_long <- to_long(feature_dt, "feature_value")

  shap_dt <- data.table::as.data.table(as.matrix(data_list[["shap_contrib"]]))
  shap_long <- to_long(shap_dt, "shap_value")

  # Pair every feature value with its SHAP contribution.
  data.table::merge.data.table(features_long, shap_long, by = c("id", "feature"))
}
|
||||
|
||||
#' Scale feature value to have mean 0, standard deviation 1
#'
#' This is used to compare multiple features on the same plot.
#' Internal utility function
#'
#' Missing values are ignored when computing the mean and standard deviation
#' and stay missing in the result. A vector without any variation (all
#' non-missing values equal, or only one non-missing value) cannot be scaled to
#' unit standard deviation; its non-missing entries are mapped to 0 instead of
#' \code{NaN}/\code{NA}.
#'
#' @param x Numeric vector
#'
#' @return Numeric vector with mean 0 and sd 1, or with all non-missing
#'   entries equal to 0 when \code{x} has no variation.
normalize <- function(x) {
  loc <- mean(x, na.rm = TRUE)
  scale <- stats::sd(x, na.rm = TRUE)
  centred <- x - loc

  # sd() is 0 for a constant vector (division would give 0 / 0 = NaN) and NA
  # when only one non-missing value is present; in both cases every observed
  # value sits exactly at the mean.
  no_variation <- sum(!is.na(x)) == 1L || isTRUE(scale == 0)
  if (no_variation) {
    centred[!is.na(centred)] <- 0
    return(centred)
  }

  centred / scale
}
|
||||
|
||||
# Plot multiple ggplot graphs aligned by rows and columns.
|
||||
# ... the plots
|
||||
# cols number of columns
|
||||
@@ -131,5 +210,5 @@ multiplot <- function(..., cols = 1) {
|
||||
|
||||
# Names that are only resolved through non-standard evaluation (ggplot2
# aesthetics and data.table `j` expressions). Declaring them keeps R CMD check
# from reporting them as undefined global variables.
globalVariables(c(
  "Cluster", "ggplot", "aes", "geom_bar", "coord_flip", "xlab", "ylab", "ggtitle", "theme",
  "element_blank", "element_text", "V1", "Weight", "feature", "shap_value", "feature_value"
))
|
||||
|
||||
Reference in New Issue
Block a user