Add SHAP summary plot using ggplot2 (#5882)

* add SHAP summary plot using ggplot2

* Update xgb.plot.shap

* Update example in xgb.plot.shap documentation

* update logic, add tests

* whitespace fixes

* whitespace fixes for test_helpers

* namespace for sd function

* explicitly declare variables that are automatically evaluated by data.table

* Fix R lint

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Cuong Duong 2020-08-19 11:04:09 +10:00 committed by GitHub
parent 989ddd036f
commit e51cba6195
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 238 additions and 55 deletions

View File

@ -99,6 +99,85 @@ xgb.ggplot.deepness <- function(model = NULL, which = c("2x1", "max.depth", "med
}
}
#' @rdname xgb.plot.shap.summary
#' @export
xgb.ggplot.shap.summary <- function(data, shap_contrib = NULL, features = NULL, top_n = 10, model = NULL,
trees = NULL, target_class = NULL, approxcontrib = FALSE, subsample = NULL) {
data_list <- xgb.shap.data(
data = data,
shap_contrib = shap_contrib,
features = features,
top_n = top_n,
model = model,
trees = trees,
target_class = target_class,
approxcontrib = approxcontrib,
subsample = subsample,
max_observations = 10000 # 10,000 samples per feature.
)
p_data <- prepare.ggplot.shap.data(data_list, normalize = TRUE)
# Reverse factor levels so that the first level is at the top of the plot
p_data[, "feature" := factor(feature, rev(levels(feature)))]
p <- ggplot2::ggplot(p_data, ggplot2::aes(x = feature, y = shap_value, colour = feature_value)) +
ggplot2::geom_jitter(alpha = 0.5, width = 0.1) +
ggplot2::scale_colour_viridis_c(limits = c(-3, 3), option = "plasma", direction = -1) +
ggplot2::geom_abline(slope = 0, intercept = 0, colour = "darkgrey") +
ggplot2::coord_flip()
p
}
#' Combine and melt feature values and SHAP contributions for sample
#' observations.
#'
#' Conforms to data format required for ggplot functions.
#'
#' Internal utility function.
#'
#' @param data_list List containing 'data' and 'shap_contrib' returned by
#' \code{xgb.shap.data()}.
#' @param normalize Whether to standardize feature values to have mean 0 and
#' standard deviation 1 (useful for comparing multiple features on the same
#' plot). Default \code{FALSE}.
#'
#' @return A data.table containing the observation ID, the feature name, the
#' feature value (normalized if specified), and the SHAP contribution value.
prepare.ggplot.shap.data <- function(data_list, normalize = FALSE) {
data <- data_list[["data"]]
shap_contrib <- data_list[["shap_contrib"]]
data <- data.table::as.data.table(as.matrix(data))
if (normalize) {
data[, (names(data)) := lapply(.SD, normalize)]
}
data[, "id" := seq_len(nrow(data))]
data_m <- data.table::melt.data.table(data, id.vars = "id", variable.name = "feature", value.name = "feature_value")
shap_contrib <- data.table::as.data.table(as.matrix(shap_contrib))
shap_contrib[, "id" := seq_len(nrow(shap_contrib))]
shap_contrib_m <- data.table::melt.data.table(shap_contrib, id.vars = "id", variable.name = "feature", value.name = "shap_value")
p_data <- data.table::merge.data.table(data_m, shap_contrib_m, by = c("id", "feature"))
p_data
}
#' Scale feature value to have mean 0, standard deviation 1
#'
#' This is used to compare multiple features on the same plot.
#' Internal utility function
#'
#' @param x Numeric vector
#'
#' @return Numeric vector with mean 0 and sd 1.
normalize <- function(x) {
loc <- mean(x, na.rm = TRUE)
scale <- stats::sd(x, na.rm = TRUE)
(x - loc) / scale
}
# Plot multiple ggplot graph aligned by rows and columns.
# ... the plots
# cols number of columns
@ -131,5 +210,5 @@ multiplot <- function(..., cols = 1) {
globalVariables(c(
"Cluster", "ggplot", "aes", "geom_bar", "coord_flip", "xlab", "ylab", "ggtitle", "theme",
"element_blank", "element_text", "V1", "Weight"
"element_blank", "element_text", "V1", "Weight", "feature"
))

View File

@ -81,6 +81,7 @@
#' xgb.plot.shap(agaricus.test$data, model = bst, features = "odor=none")
#' contr <- predict(bst, agaricus.test$data, predcontrib = TRUE)
#' xgb.plot.shap(agaricus.test$data, contr, model = bst, top_n = 12, n_col = 3)
#' xgb.ggplot.shap.summary(agaricus.test$data, contr, model = bst, top_n = 12) # Summary plot
#'
#' # multiclass example - plots for each class separately:
#' nclass <- 3
@ -99,6 +100,7 @@
#' n_col = 2, col = col, pch = 16, pch_NA = 17)
#' xgb.plot.shap(x, model = mbst, trees = trees0 + 2, target_class = 2, top_n = 4,
#' n_col = 2, col = col, pch = 16, pch_NA = 17)
#' xgb.ggplot.shap.summary(x, model = mbst, target_class = 0, top_n = 4) # Summary plot
#'
#' @rdname xgb.plot.shap
#' @export
@ -109,69 +111,33 @@ xgb.plot.shap <- function(data, shap_contrib = NULL, features = NULL, top_n = 1,
plot_NA = TRUE, col_NA = rgb(0.7, 0, 1, 0.6), pch_NA = '.', pos_NA = 1.07,
plot_loess = TRUE, col_loess = 2, span_loess = 0.5,
which = c("1d", "2d"), plot = TRUE, ...) {
if (!is.matrix(data) && !inherits(data, "dgCMatrix"))
stop("data: must be either matrix or dgCMatrix")
if (is.null(shap_contrib) && (is.null(model) || !inherits(model, "xgb.Booster")))
stop("when shap_contrib is not provided, one must provide an xgb.Booster model")
if (is.null(features) && (is.null(model) || !inherits(model, "xgb.Booster")))
stop("when features are not provided, one must provide an xgb.Booster model to rank the features")
if (!is.null(shap_contrib) &&
(!is.matrix(shap_contrib) || nrow(shap_contrib) != nrow(data) || ncol(shap_contrib) != ncol(data) + 1))
stop("shap_contrib is not compatible with the provided data")
nsample <- if (is.null(subsample)) min(100000, nrow(data)) else as.integer(subsample * nrow(data))
idx <- sample(seq_len(nrow(data)), nsample)
data <- data[idx, ]
if (is.null(shap_contrib)) {
shap_contrib <- predict(model, data, predcontrib = TRUE, approxcontrib = approxcontrib)
} else {
shap_contrib <- shap_contrib[idx, ]
}
data_list <- xgb.shap.data(
data = data,
shap_contrib = shap_contrib,
features = features,
top_n = top_n,
model = model,
trees = trees,
target_class = target_class,
approxcontrib = approxcontrib,
subsample = subsample,
max_observations = 100000
)
data <- data_list[["data"]]
shap_contrib <- data_list[["shap_contrib"]]
features <- colnames(data)
which <- match.arg(which)
if (which == "2d")
stop("2D plots are not implemented yet")
if (is.null(features)) {
imp <- xgb.importance(model = model, trees = trees)
top_n <- as.integer(top_n[1])
if (top_n < 1 && top_n > 100)
stop("top_n: must be an integer within [1, 100]")
features <- imp$Feature[1:min(top_n, NROW(imp))]
}
if (is.character(features)) {
if (is.null(colnames(data)))
stop("Either provide `data` with column names or provide `features` as column indices")
features <- match(features, colnames(data))
}
if (n_col > length(features)) n_col <- length(features)
if (is.list(shap_contrib)) { # multiclass: either choose a class or merge
shap_contrib <- if (!is.null(target_class)) shap_contrib[[target_class + 1]]
else Reduce("+", lapply(shap_contrib, abs))
}
shap_contrib <- shap_contrib[, features, drop = FALSE]
data <- data[, features, drop = FALSE]
cols <- colnames(data)
if (is.null(cols)) cols <- colnames(shap_contrib)
if (is.null(cols)) cols <- paste0('X', seq_len(ncol(data)))
colnames(data) <- cols
colnames(shap_contrib) <- cols
if (plot && which == "1d") {
op <- par(mfrow = c(ceiling(length(features) / n_col), n_col),
oma = c(0, 0, 0, 0) + 0.2,
mar = c(3.5, 3.5, 0, 0) + 0.1,
mgp = c(1.7, 0.6, 0))
for (f in cols) {
for (f in features) {
ord <- order(data[, f])
x <- data[, f][ord]
y <- shap_contrib[, f][ord]
@ -216,3 +182,105 @@ xgb.plot.shap <- function(data, shap_contrib = NULL, features = NULL, top_n = 1,
}
invisible(list(data = data, shap_contrib = shap_contrib))
}
#' SHAP contribution dependency summary plot
#'
#' Compare SHAP contributions of different features.
#'
#' A point plot (each point representing one sample from \code{data}) is
#' produced for each feature, with the points plotted on the SHAP value axis.
#' Each point (observation) is coloured based on its feature value. The plot
#' hence allows us to see which features have a negative / positive contribution
#' on the model prediction, and whether the contribution is different for larger
#' or smaller values of the feature. We effectively try to replicate the
#' \code{summary_plot} function from https://github.com/slundberg/shap.
#'
#' @inheritParams xgb.plot.shap
#'
#' @return A \code{ggplot2} object.
#' @export
#'
#' @examples See \code{\link{xgb.plot.shap}}.
#' @seealso \code{\link{xgb.plot.shap}}, \code{\link{xgb.ggplot.shap.summary}},
#' \code{\url{https://github.com/slundberg/shap}}
xgb.plot.shap.summary <- function(data, shap_contrib = NULL, features = NULL, top_n = 10, model = NULL,
trees = NULL, target_class = NULL, approxcontrib = FALSE, subsample = NULL) {
# Only ggplot implementation is available.
xgb.ggplot.shap.summary(data, shap_contrib, features, top_n, model, trees, target_class, approxcontrib, subsample)
}
#' Prepare data for SHAP plots. To be used in xgb.plot.shap, xgb.plot.shap.summary, etc.
#' Internal utility function.
#'
#' @return A list containing: 'data', a matrix containing sample observations
#' and their feature values; 'shap_contrib', a matrix containing the SHAP contribution
#' values for these observations.
xgb.shap.data <- function(data, shap_contrib = NULL, features = NULL, top_n = 1, model = NULL,
trees = NULL, target_class = NULL, approxcontrib = FALSE,
subsample = NULL, max_observations = 100000) {
if (!is.matrix(data) && !inherits(data, "dgCMatrix"))
stop("data: must be either matrix or dgCMatrix")
if (is.null(shap_contrib) && (is.null(model) || !inherits(model, "xgb.Booster")))
stop("when shap_contrib is not provided, one must provide an xgb.Booster model")
if (is.null(features) && (is.null(model) || !inherits(model, "xgb.Booster")))
stop("when features are not provided, one must provide an xgb.Booster model to rank the features")
if (!is.null(shap_contrib) &&
(!is.matrix(shap_contrib) || nrow(shap_contrib) != nrow(data) || ncol(shap_contrib) != ncol(data) + 1))
stop("shap_contrib is not compatible with the provided data")
if (is.character(features) && is.null(colnames(data)))
stop("either provide `data` with column names or provide `features` as column indices")
if (is.null(model$feature_names) && model$nfeatures != ncol(data))
stop("if model has no feature_names, columns in `data` must match features in model")
if (!is.null(subsample)) {
idx <- sample(x = seq_len(nrow(data)), size = as.integer(subsample * nrow(data)), replace = FALSE)
} else {
idx <- seq_len(min(nrow(data), max_observations))
}
data <- data[idx, ]
if (is.null(colnames(data))) {
colnames(data) <- paste0("X", seq_len(ncol(data)))
}
if (!is.null(shap_contrib)) {
if (is.list(shap_contrib)) { # multiclass: either choose a class or merge
shap_contrib <- if (!is.null(target_class)) shap_contrib[[target_class + 1]] else Reduce("+", lapply(shap_contrib, abs))
}
shap_contrib <- shap_contrib[idx, ]
if (is.null(colnames(shap_contrib))) {
colnames(shap_contrib) <- paste0("X", seq_len(ncol(data)))
}
} else {
shap_contrib <- predict(model, newdata = data, predcontrib = TRUE, approxcontrib = approxcontrib)
if (is.list(shap_contrib)) { # multiclass: either choose a class or merge
shap_contrib <- if (!is.null(target_class)) shap_contrib[[target_class + 1]] else Reduce("+", lapply(shap_contrib, abs))
}
}
if (is.null(features)) {
if (!is.null(model$feature_names)) {
imp <- xgb.importance(model = model, trees = trees)
} else {
imp <- xgb.importance(model = model, trees = trees, feature_names = colnames(data))
}
top_n <- top_n[1]
if (top_n < 1 | top_n > 100) stop("top_n: must be an integer within [1, 100]")
features <- imp$Feature[1:min(top_n, NROW(imp))]
}
if (is.character(features)) {
features <- match(features, colnames(data))
}
shap_contrib <- shap_contrib[, features, drop = FALSE]
data <- data[, features, drop = FALSE]
list(
data = data,
shap_contrib = shap_contrib
)
}

View File

@ -351,11 +351,47 @@ test_that("xgb.plot.deepness works", {
xgb.ggplot.deepness(model = bst.Tree)
})
test_that("xgb.shap.data works when top_n is provided", {
data_list <- xgb.shap.data(data = sparse_matrix, model = bst.Tree, top_n = 2)
expect_equal(names(data_list), c("data", "shap_contrib"))
expect_equal(NCOL(data_list$data), 2)
expect_equal(NCOL(data_list$shap_contrib), 2)
expect_equal(NROW(data_list$data), NROW(data_list$shap_contrib))
expect_gt(length(colnames(data_list$data)), 0)
expect_gt(length(colnames(data_list$shap_contrib)), 0)
# for multiclass without target class provided
data_list <- xgb.shap.data(data = as.matrix(iris[, -5]), model = mbst.Tree, top_n = 2)
expect_equal(dim(data_list$shap_contrib), c(nrow(iris), 2))
# for multiclass with target class provided
data_list <- xgb.shap.data(data = as.matrix(iris[, -5]), model = mbst.Tree, top_n = 2, target_class = 0)
expect_equal(dim(data_list$shap_contrib), c(nrow(iris), 2))
})
test_that("xgb.shap.data works with subsampling", {
data_list <- xgb.shap.data(data = sparse_matrix, model = bst.Tree, top_n = 2, subsample = 0.8)
expect_equal(NROW(data_list$data), as.integer(0.8 * nrow(sparse_matrix)))
expect_equal(NROW(data_list$data), NROW(data_list$shap_contrib))
})
test_that("prepare.ggplot.shap.data works", {
data_list <- xgb.shap.data(data = sparse_matrix, model = bst.Tree, top_n = 2)
plot_data <- prepare.ggplot.shap.data(data_list, normalize = TRUE)
expect_s3_class(plot_data, "data.frame")
expect_equal(names(plot_data), c("id", "feature", "feature_value", "shap_value"))
expect_s3_class(plot_data$feature, "factor")
# Each observation should have 1 row for each feature
expect_equal(nrow(plot_data), nrow(sparse_matrix) * 2)
})
test_that("xgb.plot.shap works", {
sh <- xgb.plot.shap(data = sparse_matrix, model = bst.Tree, top_n = 2, col = 4)
expect_equal(names(sh), c("data", "shap_contrib"))
expect_equal(NCOL(sh$data), 2)
expect_equal(NCOL(sh$shap_contrib), 2)
})
test_that("xgb.plot.shap.summary works", {
xgb.plot.shap.summary(data = sparse_matrix, model = bst.Tree, top_n = 2)
xgb.ggplot.shap.summary(data = sparse_matrix, model = bst.Tree, top_n = 2)
})
test_that("check.deprecation works", {