Add SHAP summary plot using ggplot2 (#5882)
* add SHAP summary plot using ggplot2 * Update xgb.plot.shap * Update example in xgb.plot.shap documentation * update logic, add tests * whitespace fixes * whitespace fixes for test_helpers * namespace for sd function * explicitly declare variables that are automatically evaluated by data.table * Fix R lint Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
parent
989ddd036f
commit
e51cba6195
@ -99,6 +99,85 @@ xgb.ggplot.deepness <- function(model = NULL, which = c("2x1", "max.depth", "med
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#' @rdname xgb.plot.shap.summary
|
||||||
|
#' @export
|
||||||
|
xgb.ggplot.shap.summary <- function(data, shap_contrib = NULL, features = NULL, top_n = 10, model = NULL,
|
||||||
|
trees = NULL, target_class = NULL, approxcontrib = FALSE, subsample = NULL) {
|
||||||
|
data_list <- xgb.shap.data(
|
||||||
|
data = data,
|
||||||
|
shap_contrib = shap_contrib,
|
||||||
|
features = features,
|
||||||
|
top_n = top_n,
|
||||||
|
model = model,
|
||||||
|
trees = trees,
|
||||||
|
target_class = target_class,
|
||||||
|
approxcontrib = approxcontrib,
|
||||||
|
subsample = subsample,
|
||||||
|
max_observations = 10000 # 10,000 samples per feature.
|
||||||
|
)
|
||||||
|
p_data <- prepare.ggplot.shap.data(data_list, normalize = TRUE)
|
||||||
|
# Reverse factor levels so that the first level is at the top of the plot
|
||||||
|
p_data[, "feature" := factor(feature, rev(levels(feature)))]
|
||||||
|
|
||||||
|
p <- ggplot2::ggplot(p_data, ggplot2::aes(x = feature, y = shap_value, colour = feature_value)) +
|
||||||
|
ggplot2::geom_jitter(alpha = 0.5, width = 0.1) +
|
||||||
|
ggplot2::scale_colour_viridis_c(limits = c(-3, 3), option = "plasma", direction = -1) +
|
||||||
|
ggplot2::geom_abline(slope = 0, intercept = 0, colour = "darkgrey") +
|
||||||
|
ggplot2::coord_flip()
|
||||||
|
|
||||||
|
p
|
||||||
|
}
|
||||||
|
|
||||||
|
#' Combine and melt feature values and SHAP contributions for sample
|
||||||
|
#' observations.
|
||||||
|
#'
|
||||||
|
#' Conforms to data format required for ggplot functions.
|
||||||
|
#'
|
||||||
|
#' Internal utility function.
|
||||||
|
#'
|
||||||
|
#' @param data_list List containing 'data' and 'shap_contrib' returned by
|
||||||
|
#' \code{xgb.shap.data()}.
|
||||||
|
#' @param normalize Whether to standardize feature values to have mean 0 and
|
||||||
|
#' standard deviation 1 (useful for comparing multiple features on the same
|
||||||
|
#' plot). Default \code{FALSE}.
|
||||||
|
#'
|
||||||
|
#' @return A data.table containing the observation ID, the feature name, the
|
||||||
|
#' feature value (normalized if specified), and the SHAP contribution value.
|
||||||
|
prepare.ggplot.shap.data <- function(data_list, normalize = FALSE) {
|
||||||
|
data <- data_list[["data"]]
|
||||||
|
shap_contrib <- data_list[["shap_contrib"]]
|
||||||
|
|
||||||
|
data <- data.table::as.data.table(as.matrix(data))
|
||||||
|
if (normalize) {
|
||||||
|
data[, (names(data)) := lapply(.SD, normalize)]
|
||||||
|
}
|
||||||
|
data[, "id" := seq_len(nrow(data))]
|
||||||
|
data_m <- data.table::melt.data.table(data, id.vars = "id", variable.name = "feature", value.name = "feature_value")
|
||||||
|
|
||||||
|
shap_contrib <- data.table::as.data.table(as.matrix(shap_contrib))
|
||||||
|
shap_contrib[, "id" := seq_len(nrow(shap_contrib))]
|
||||||
|
shap_contrib_m <- data.table::melt.data.table(shap_contrib, id.vars = "id", variable.name = "feature", value.name = "shap_value")
|
||||||
|
|
||||||
|
p_data <- data.table::merge.data.table(data_m, shap_contrib_m, by = c("id", "feature"))
|
||||||
|
|
||||||
|
p_data
|
||||||
|
}
|
||||||
|
|
||||||
|
#' Scale feature value to have mean 0, standard deviation 1
|
||||||
|
#'
|
||||||
|
#' This is used to compare multiple features on the same plot.
|
||||||
|
#' Internal utility function
|
||||||
|
#'
|
||||||
|
#' @param x Numeric vector
|
||||||
|
#'
|
||||||
|
#' @return Numeric vector with mean 0 and sd 1.
|
||||||
|
normalize <- function(x) {
|
||||||
|
loc <- mean(x, na.rm = TRUE)
|
||||||
|
scale <- stats::sd(x, na.rm = TRUE)
|
||||||
|
|
||||||
|
(x - loc) / scale
|
||||||
|
}
|
||||||
|
|
||||||
# Plot multiple ggplot graph aligned by rows and columns.
|
# Plot multiple ggplot graph aligned by rows and columns.
|
||||||
# ... the plots
|
# ... the plots
|
||||||
# cols number of columns
|
# cols number of columns
|
||||||
@ -131,5 +210,5 @@ multiplot <- function(..., cols = 1) {
|
|||||||
|
|
||||||
globalVariables(c(
|
globalVariables(c(
|
||||||
"Cluster", "ggplot", "aes", "geom_bar", "coord_flip", "xlab", "ylab", "ggtitle", "theme",
|
"Cluster", "ggplot", "aes", "geom_bar", "coord_flip", "xlab", "ylab", "ggtitle", "theme",
|
||||||
"element_blank", "element_text", "V1", "Weight"
|
"element_blank", "element_text", "V1", "Weight", "feature"
|
||||||
))
|
))
|
||||||
|
|||||||
@ -81,6 +81,7 @@
|
|||||||
#' xgb.plot.shap(agaricus.test$data, model = bst, features = "odor=none")
|
#' xgb.plot.shap(agaricus.test$data, model = bst, features = "odor=none")
|
||||||
#' contr <- predict(bst, agaricus.test$data, predcontrib = TRUE)
|
#' contr <- predict(bst, agaricus.test$data, predcontrib = TRUE)
|
||||||
#' xgb.plot.shap(agaricus.test$data, contr, model = bst, top_n = 12, n_col = 3)
|
#' xgb.plot.shap(agaricus.test$data, contr, model = bst, top_n = 12, n_col = 3)
|
||||||
|
#' xgb.ggplot.shap.summary(agaricus.test$data, contr, model = bst, top_n = 12) # Summary plot
|
||||||
#'
|
#'
|
||||||
#' # multiclass example - plots for each class separately:
|
#' # multiclass example - plots for each class separately:
|
||||||
#' nclass <- 3
|
#' nclass <- 3
|
||||||
@ -99,6 +100,7 @@
|
|||||||
#' n_col = 2, col = col, pch = 16, pch_NA = 17)
|
#' n_col = 2, col = col, pch = 16, pch_NA = 17)
|
||||||
#' xgb.plot.shap(x, model = mbst, trees = trees0 + 2, target_class = 2, top_n = 4,
|
#' xgb.plot.shap(x, model = mbst, trees = trees0 + 2, target_class = 2, top_n = 4,
|
||||||
#' n_col = 2, col = col, pch = 16, pch_NA = 17)
|
#' n_col = 2, col = col, pch = 16, pch_NA = 17)
|
||||||
|
#' xgb.ggplot.shap.summary(x, model = mbst, target_class = 0, top_n = 4) # Summary plot
|
||||||
#'
|
#'
|
||||||
#' @rdname xgb.plot.shap
|
#' @rdname xgb.plot.shap
|
||||||
#' @export
|
#' @export
|
||||||
@ -109,69 +111,33 @@ xgb.plot.shap <- function(data, shap_contrib = NULL, features = NULL, top_n = 1,
|
|||||||
plot_NA = TRUE, col_NA = rgb(0.7, 0, 1, 0.6), pch_NA = '.', pos_NA = 1.07,
|
plot_NA = TRUE, col_NA = rgb(0.7, 0, 1, 0.6), pch_NA = '.', pos_NA = 1.07,
|
||||||
plot_loess = TRUE, col_loess = 2, span_loess = 0.5,
|
plot_loess = TRUE, col_loess = 2, span_loess = 0.5,
|
||||||
which = c("1d", "2d"), plot = TRUE, ...) {
|
which = c("1d", "2d"), plot = TRUE, ...) {
|
||||||
|
data_list <- xgb.shap.data(
|
||||||
if (!is.matrix(data) && !inherits(data, "dgCMatrix"))
|
data = data,
|
||||||
stop("data: must be either matrix or dgCMatrix")
|
shap_contrib = shap_contrib,
|
||||||
|
features = features,
|
||||||
if (is.null(shap_contrib) && (is.null(model) || !inherits(model, "xgb.Booster")))
|
top_n = top_n,
|
||||||
stop("when shap_contrib is not provided, one must provide an xgb.Booster model")
|
model = model,
|
||||||
|
trees = trees,
|
||||||
if (is.null(features) && (is.null(model) || !inherits(model, "xgb.Booster")))
|
target_class = target_class,
|
||||||
stop("when features are not provided, one must provide an xgb.Booster model to rank the features")
|
approxcontrib = approxcontrib,
|
||||||
|
subsample = subsample,
|
||||||
if (!is.null(shap_contrib) &&
|
max_observations = 100000
|
||||||
(!is.matrix(shap_contrib) || nrow(shap_contrib) != nrow(data) || ncol(shap_contrib) != ncol(data) + 1))
|
)
|
||||||
stop("shap_contrib is not compatible with the provided data")
|
data <- data_list[["data"]]
|
||||||
|
shap_contrib <- data_list[["shap_contrib"]]
|
||||||
nsample <- if (is.null(subsample)) min(100000, nrow(data)) else as.integer(subsample * nrow(data))
|
features <- colnames(data)
|
||||||
idx <- sample(seq_len(nrow(data)), nsample)
|
|
||||||
data <- data[idx, ]
|
|
||||||
|
|
||||||
if (is.null(shap_contrib)) {
|
|
||||||
shap_contrib <- predict(model, data, predcontrib = TRUE, approxcontrib = approxcontrib)
|
|
||||||
} else {
|
|
||||||
shap_contrib <- shap_contrib[idx, ]
|
|
||||||
}
|
|
||||||
|
|
||||||
which <- match.arg(which)
|
which <- match.arg(which)
|
||||||
if (which == "2d")
|
if (which == "2d")
|
||||||
stop("2D plots are not implemented yet")
|
stop("2D plots are not implemented yet")
|
||||||
|
|
||||||
if (is.null(features)) {
|
|
||||||
imp <- xgb.importance(model = model, trees = trees)
|
|
||||||
top_n <- as.integer(top_n[1])
|
|
||||||
if (top_n < 1 && top_n > 100)
|
|
||||||
stop("top_n: must be an integer within [1, 100]")
|
|
||||||
features <- imp$Feature[1:min(top_n, NROW(imp))]
|
|
||||||
}
|
|
||||||
|
|
||||||
if (is.character(features)) {
|
|
||||||
if (is.null(colnames(data)))
|
|
||||||
stop("Either provide `data` with column names or provide `features` as column indices")
|
|
||||||
features <- match(features, colnames(data))
|
|
||||||
}
|
|
||||||
|
|
||||||
if (n_col > length(features)) n_col <- length(features)
|
if (n_col > length(features)) n_col <- length(features)
|
||||||
|
|
||||||
if (is.list(shap_contrib)) { # multiclass: either choose a class or merge
|
|
||||||
shap_contrib <- if (!is.null(target_class)) shap_contrib[[target_class + 1]]
|
|
||||||
else Reduce("+", lapply(shap_contrib, abs))
|
|
||||||
}
|
|
||||||
|
|
||||||
shap_contrib <- shap_contrib[, features, drop = FALSE]
|
|
||||||
data <- data[, features, drop = FALSE]
|
|
||||||
cols <- colnames(data)
|
|
||||||
if (is.null(cols)) cols <- colnames(shap_contrib)
|
|
||||||
if (is.null(cols)) cols <- paste0('X', seq_len(ncol(data)))
|
|
||||||
colnames(data) <- cols
|
|
||||||
colnames(shap_contrib) <- cols
|
|
||||||
|
|
||||||
if (plot && which == "1d") {
|
if (plot && which == "1d") {
|
||||||
op <- par(mfrow = c(ceiling(length(features) / n_col), n_col),
|
op <- par(mfrow = c(ceiling(length(features) / n_col), n_col),
|
||||||
oma = c(0, 0, 0, 0) + 0.2,
|
oma = c(0, 0, 0, 0) + 0.2,
|
||||||
mar = c(3.5, 3.5, 0, 0) + 0.1,
|
mar = c(3.5, 3.5, 0, 0) + 0.1,
|
||||||
mgp = c(1.7, 0.6, 0))
|
mgp = c(1.7, 0.6, 0))
|
||||||
for (f in cols) {
|
for (f in features) {
|
||||||
ord <- order(data[, f])
|
ord <- order(data[, f])
|
||||||
x <- data[, f][ord]
|
x <- data[, f][ord]
|
||||||
y <- shap_contrib[, f][ord]
|
y <- shap_contrib[, f][ord]
|
||||||
@ -216,3 +182,105 @@ xgb.plot.shap <- function(data, shap_contrib = NULL, features = NULL, top_n = 1,
|
|||||||
}
|
}
|
||||||
invisible(list(data = data, shap_contrib = shap_contrib))
|
invisible(list(data = data, shap_contrib = shap_contrib))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#' SHAP contribution dependency summary plot
|
||||||
|
#'
|
||||||
|
#' Compare SHAP contributions of different features.
|
||||||
|
#'
|
||||||
|
#' A point plot (each point representing one sample from \code{data}) is
|
||||||
|
#' produced for each feature, with the points plotted on the SHAP value axis.
|
||||||
|
#' Each point (observation) is coloured based on its feature value. The plot
|
||||||
|
#' hence allows us to see which features have a negative / positive contribution
|
||||||
|
#' on the model prediction, and whether the contribution is different for larger
|
||||||
|
#' or smaller values of the feature. We effectively try to replicate the
|
||||||
|
#' \code{summary_plot} function from https://github.com/slundberg/shap.
|
||||||
|
#'
|
||||||
|
#' @inheritParams xgb.plot.shap
|
||||||
|
#'
|
||||||
|
#' @return A \code{ggplot2} object.
|
||||||
|
#' @export
|
||||||
|
#'
|
||||||
|
#' @examples See \code{\link{xgb.plot.shap}}.
|
||||||
|
#' @seealso \code{\link{xgb.plot.shap}}, \code{\link{xgb.ggplot.shap.summary}},
|
||||||
|
#' \code{\url{https://github.com/slundberg/shap}}
|
||||||
|
xgb.plot.shap.summary <- function(data, shap_contrib = NULL, features = NULL, top_n = 10, model = NULL,
|
||||||
|
trees = NULL, target_class = NULL, approxcontrib = FALSE, subsample = NULL) {
|
||||||
|
# Only ggplot implementation is available.
|
||||||
|
xgb.ggplot.shap.summary(data, shap_contrib, features, top_n, model, trees, target_class, approxcontrib, subsample)
|
||||||
|
}
|
||||||
|
|
||||||
|
#' Prepare data for SHAP plots. To be used in xgb.plot.shap, xgb.plot.shap.summary, etc.
|
||||||
|
#' Internal utility function.
|
||||||
|
#'
|
||||||
|
#' @return A list containing: 'data', a matrix containing sample observations
|
||||||
|
#' and their feature values; 'shap_contrib', a matrix containing the SHAP contribution
|
||||||
|
#' values for these observations.
|
||||||
|
xgb.shap.data <- function(data, shap_contrib = NULL, features = NULL, top_n = 1, model = NULL,
|
||||||
|
trees = NULL, target_class = NULL, approxcontrib = FALSE,
|
||||||
|
subsample = NULL, max_observations = 100000) {
|
||||||
|
if (!is.matrix(data) && !inherits(data, "dgCMatrix"))
|
||||||
|
stop("data: must be either matrix or dgCMatrix")
|
||||||
|
|
||||||
|
if (is.null(shap_contrib) && (is.null(model) || !inherits(model, "xgb.Booster")))
|
||||||
|
stop("when shap_contrib is not provided, one must provide an xgb.Booster model")
|
||||||
|
|
||||||
|
if (is.null(features) && (is.null(model) || !inherits(model, "xgb.Booster")))
|
||||||
|
stop("when features are not provided, one must provide an xgb.Booster model to rank the features")
|
||||||
|
|
||||||
|
if (!is.null(shap_contrib) &&
|
||||||
|
(!is.matrix(shap_contrib) || nrow(shap_contrib) != nrow(data) || ncol(shap_contrib) != ncol(data) + 1))
|
||||||
|
stop("shap_contrib is not compatible with the provided data")
|
||||||
|
|
||||||
|
if (is.character(features) && is.null(colnames(data)))
|
||||||
|
stop("either provide `data` with column names or provide `features` as column indices")
|
||||||
|
|
||||||
|
if (is.null(model$feature_names) && model$nfeatures != ncol(data))
|
||||||
|
stop("if model has no feature_names, columns in `data` must match features in model")
|
||||||
|
|
||||||
|
if (!is.null(subsample)) {
|
||||||
|
idx <- sample(x = seq_len(nrow(data)), size = as.integer(subsample * nrow(data)), replace = FALSE)
|
||||||
|
} else {
|
||||||
|
idx <- seq_len(min(nrow(data), max_observations))
|
||||||
|
}
|
||||||
|
data <- data[idx, ]
|
||||||
|
if (is.null(colnames(data))) {
|
||||||
|
colnames(data) <- paste0("X", seq_len(ncol(data)))
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!is.null(shap_contrib)) {
|
||||||
|
if (is.list(shap_contrib)) { # multiclass: either choose a class or merge
|
||||||
|
shap_contrib <- if (!is.null(target_class)) shap_contrib[[target_class + 1]] else Reduce("+", lapply(shap_contrib, abs))
|
||||||
|
}
|
||||||
|
shap_contrib <- shap_contrib[idx, ]
|
||||||
|
if (is.null(colnames(shap_contrib))) {
|
||||||
|
colnames(shap_contrib) <- paste0("X", seq_len(ncol(data)))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
shap_contrib <- predict(model, newdata = data, predcontrib = TRUE, approxcontrib = approxcontrib)
|
||||||
|
if (is.list(shap_contrib)) { # multiclass: either choose a class or merge
|
||||||
|
shap_contrib <- if (!is.null(target_class)) shap_contrib[[target_class + 1]] else Reduce("+", lapply(shap_contrib, abs))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (is.null(features)) {
|
||||||
|
if (!is.null(model$feature_names)) {
|
||||||
|
imp <- xgb.importance(model = model, trees = trees)
|
||||||
|
} else {
|
||||||
|
imp <- xgb.importance(model = model, trees = trees, feature_names = colnames(data))
|
||||||
|
}
|
||||||
|
top_n <- top_n[1]
|
||||||
|
if (top_n < 1 | top_n > 100) stop("top_n: must be an integer within [1, 100]")
|
||||||
|
features <- imp$Feature[1:min(top_n, NROW(imp))]
|
||||||
|
}
|
||||||
|
if (is.character(features)) {
|
||||||
|
features <- match(features, colnames(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
shap_contrib <- shap_contrib[, features, drop = FALSE]
|
||||||
|
data <- data[, features, drop = FALSE]
|
||||||
|
|
||||||
|
list(
|
||||||
|
data = data,
|
||||||
|
shap_contrib = shap_contrib
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|||||||
@ -351,11 +351,47 @@ test_that("xgb.plot.deepness works", {
|
|||||||
xgb.ggplot.deepness(model = bst.Tree)
|
xgb.ggplot.deepness(model = bst.Tree)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test_that("xgb.shap.data works when top_n is provided", {
|
||||||
|
data_list <- xgb.shap.data(data = sparse_matrix, model = bst.Tree, top_n = 2)
|
||||||
|
expect_equal(names(data_list), c("data", "shap_contrib"))
|
||||||
|
expect_equal(NCOL(data_list$data), 2)
|
||||||
|
expect_equal(NCOL(data_list$shap_contrib), 2)
|
||||||
|
expect_equal(NROW(data_list$data), NROW(data_list$shap_contrib))
|
||||||
|
expect_gt(length(colnames(data_list$data)), 0)
|
||||||
|
expect_gt(length(colnames(data_list$shap_contrib)), 0)
|
||||||
|
|
||||||
|
# for multiclass without target class provided
|
||||||
|
data_list <- xgb.shap.data(data = as.matrix(iris[, -5]), model = mbst.Tree, top_n = 2)
|
||||||
|
expect_equal(dim(data_list$shap_contrib), c(nrow(iris), 2))
|
||||||
|
# for multiclass with target class provided
|
||||||
|
data_list <- xgb.shap.data(data = as.matrix(iris[, -5]), model = mbst.Tree, top_n = 2, target_class = 0)
|
||||||
|
expect_equal(dim(data_list$shap_contrib), c(nrow(iris), 2))
|
||||||
|
})
|
||||||
|
|
||||||
|
test_that("xgb.shap.data works with subsampling", {
|
||||||
|
data_list <- xgb.shap.data(data = sparse_matrix, model = bst.Tree, top_n = 2, subsample = 0.8)
|
||||||
|
expect_equal(NROW(data_list$data), as.integer(0.8 * nrow(sparse_matrix)))
|
||||||
|
expect_equal(NROW(data_list$data), NROW(data_list$shap_contrib))
|
||||||
|
})
|
||||||
|
|
||||||
|
test_that("prepare.ggplot.shap.data works", {
|
||||||
|
data_list <- xgb.shap.data(data = sparse_matrix, model = bst.Tree, top_n = 2)
|
||||||
|
plot_data <- prepare.ggplot.shap.data(data_list, normalize = TRUE)
|
||||||
|
expect_s3_class(plot_data, "data.frame")
|
||||||
|
expect_equal(names(plot_data), c("id", "feature", "feature_value", "shap_value"))
|
||||||
|
expect_s3_class(plot_data$feature, "factor")
|
||||||
|
# Each observation should have 1 row for each feature
|
||||||
|
expect_equal(nrow(plot_data), nrow(sparse_matrix) * 2)
|
||||||
|
})
|
||||||
|
|
||||||
test_that("xgb.plot.shap works", {
|
test_that("xgb.plot.shap works", {
|
||||||
sh <- xgb.plot.shap(data = sparse_matrix, model = bst.Tree, top_n = 2, col = 4)
|
sh <- xgb.plot.shap(data = sparse_matrix, model = bst.Tree, top_n = 2, col = 4)
|
||||||
expect_equal(names(sh), c("data", "shap_contrib"))
|
expect_equal(names(sh), c("data", "shap_contrib"))
|
||||||
expect_equal(NCOL(sh$data), 2)
|
})
|
||||||
expect_equal(NCOL(sh$shap_contrib), 2)
|
|
||||||
|
test_that("xgb.plot.shap.summary works", {
|
||||||
|
xgb.plot.shap.summary(data = sparse_matrix, model = bst.Tree, top_n = 2)
|
||||||
|
xgb.ggplot.shap.summary(data = sparse_matrix, model = bst.Tree, top_n = 2)
|
||||||
})
|
})
|
||||||
|
|
||||||
test_that("check.deprecation works", {
|
test_that("check.deprecation works", {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user