diff --git a/R-package/R/xgb.ggplot.R b/R-package/R/xgb.ggplot.R index 3b76e9fac..339e0fac1 100644 --- a/R-package/R/xgb.ggplot.R +++ b/R-package/R/xgb.ggplot.R @@ -99,6 +99,85 @@ xgb.ggplot.deepness <- function(model = NULL, which = c("2x1", "max.depth", "med } } +#' @rdname xgb.plot.shap.summary +#' @export +xgb.ggplot.shap.summary <- function(data, shap_contrib = NULL, features = NULL, top_n = 10, model = NULL, + trees = NULL, target_class = NULL, approxcontrib = FALSE, subsample = NULL) { + data_list <- xgb.shap.data( + data = data, + shap_contrib = shap_contrib, + features = features, + top_n = top_n, + model = model, + trees = trees, + target_class = target_class, + approxcontrib = approxcontrib, + subsample = subsample, + max_observations = 10000 # 10,000 samples per feature. + ) + p_data <- prepare.ggplot.shap.data(data_list, normalize = TRUE) + # Reverse factor levels so that the first level is at the top of the plot + p_data[, "feature" := factor(feature, rev(levels(feature)))] + + p <- ggplot2::ggplot(p_data, ggplot2::aes(x = feature, y = shap_value, colour = feature_value)) + + ggplot2::geom_jitter(alpha = 0.5, width = 0.1) + + ggplot2::scale_colour_viridis_c(limits = c(-3, 3), option = "plasma", direction = -1) + + ggplot2::geom_abline(slope = 0, intercept = 0, colour = "darkgrey") + + ggplot2::coord_flip() + + p +} + +#' Combine and melt feature values and SHAP contributions for sample +#' observations. +#' +#' Conforms to data format required for ggplot functions. +#' +#' Internal utility function. +#' +#' @param data_list List containing 'data' and 'shap_contrib' returned by +#' \code{xgb.shap.data()}. +#' @param normalize Whether to standardize feature values to have mean 0 and +#' standard deviation 1 (useful for comparing multiple features on the same +#' plot). Default \code{FALSE}. +#' +#' @return A data.table containing the observation ID, the feature name, the +#' feature value (normalized if specified), and the SHAP contribution value. +prepare.ggplot.shap.data <- function(data_list, normalize = FALSE) { + data <- data_list[["data"]] + shap_contrib <- data_list[["shap_contrib"]] + + data <- data.table::as.data.table(as.matrix(data)) + if (normalize) { + data[, (names(data)) := lapply(.SD, normalize)] + } + data[, "id" := seq_len(nrow(data))] + data_m <- data.table::melt.data.table(data, id.vars = "id", variable.name = "feature", value.name = "feature_value") + + shap_contrib <- data.table::as.data.table(as.matrix(shap_contrib)) + shap_contrib[, "id" := seq_len(nrow(shap_contrib))] + shap_contrib_m <- data.table::melt.data.table(shap_contrib, id.vars = "id", variable.name = "feature", value.name = "shap_value") + + p_data <- data.table::merge.data.table(data_m, shap_contrib_m, by = c("id", "feature")) + + p_data +} + +#' Scale feature value to have mean 0, standard deviation 1 +#' +#' This is used to compare multiple features on the same plot. +#' Internal utility function +#' +#' @param x Numeric vector +#' +#' @return Numeric vector with mean 0 and sd 1. +normalize <- function(x) { + loc <- mean(x, na.rm = TRUE) + scale <- stats::sd(x, na.rm = TRUE) + + (x - loc) / scale +} + # Plot multiple ggplot graph aligned by rows and columns. # ... the plots # cols number of columns @@ -131,5 +210,5 @@ multiplot <- function(..., cols = 1) { globalVariables(c( "Cluster", "ggplot", "aes", "geom_bar", "coord_flip", "xlab", "ylab", "ggtitle", "theme", - "element_blank", "element_text", "V1", "Weight" + "element_blank", "element_text", "V1", "Weight", "feature" )) diff --git a/R-package/R/xgb.plot.shap.R b/R-package/R/xgb.plot.shap.R index a44d4b570..d9ea69786 100644 --- a/R-package/R/xgb.plot.shap.R +++ b/R-package/R/xgb.plot.shap.R @@ -81,6 +81,7 @@ #' xgb.plot.shap(agaricus.test$data, model = bst, features = "odor=none") #' contr <- predict(bst, agaricus.test$data, predcontrib = TRUE) #' xgb.plot.shap(agaricus.test$data, contr, model = bst, top_n = 12, n_col = 3) +#' xgb.ggplot.shap.summary(agaricus.test$data, contr, model = bst, top_n = 12) # Summary plot #' #' # multiclass example - plots for each class separately: #' nclass <- 3 @@ -99,6 +100,7 @@ #' n_col = 2, col = col, pch = 16, pch_NA = 17) #' xgb.plot.shap(x, model = mbst, trees = trees0 + 2, target_class = 2, top_n = 4, #' n_col = 2, col = col, pch = 16, pch_NA = 17) +#' xgb.ggplot.shap.summary(x, model = mbst, target_class = 0, top_n = 4) # Summary plot #' #' @rdname xgb.plot.shap #' @export @@ -109,69 +111,33 @@ xgb.plot.shap <- function(data, shap_contrib = NULL, features = NULL, top_n = 1, plot_NA = TRUE, col_NA = rgb(0.7, 0, 1, 0.6), pch_NA = '.', pos_NA = 1.07, plot_loess = TRUE, col_loess = 2, span_loess = 0.5, which = c("1d", "2d"), plot = TRUE, ...) { - - if (!is.matrix(data) && !inherits(data, "dgCMatrix")) - stop("data: must be either matrix or dgCMatrix") - - if (is.null(shap_contrib) && (is.null(model) || !inherits(model, "xgb.Booster"))) - stop("when shap_contrib is not provided, one must provide an xgb.Booster model") - - if (is.null(features) && (is.null(model) || !inherits(model, "xgb.Booster"))) - stop("when features are not provided, one must provide an xgb.Booster model to rank the features") - - if (!is.null(shap_contrib) && - (!is.matrix(shap_contrib) || nrow(shap_contrib) != nrow(data) || ncol(shap_contrib) != ncol(data) + 1)) - stop("shap_contrib is not compatible with the provided data") - - nsample <- if (is.null(subsample)) min(100000, nrow(data)) else as.integer(subsample * nrow(data)) - idx <- sample(seq_len(nrow(data)), nsample) - data <- data[idx, ] - - if (is.null(shap_contrib)) { - shap_contrib <- predict(model, data, predcontrib = TRUE, approxcontrib = approxcontrib) - } else { - shap_contrib <- shap_contrib[idx, ] - } + data_list <- xgb.shap.data( + data = data, + shap_contrib = shap_contrib, + features = features, + top_n = top_n, + model = model, + trees = trees, + target_class = target_class, + approxcontrib = approxcontrib, + subsample = subsample, + max_observations = 100000 + ) + data <- data_list[["data"]] + shap_contrib <- data_list[["shap_contrib"]] + features <- colnames(data) which <- match.arg(which) if (which == "2d") stop("2D plots are not implemented yet") - if (is.null(features)) { - imp <- xgb.importance(model = model, trees = trees) - top_n <- as.integer(top_n[1]) - if (top_n < 1 && top_n > 100) - stop("top_n: must be an integer within [1, 100]") - features <- imp$Feature[1:min(top_n, NROW(imp))] - } - - if (is.character(features)) { - if (is.null(colnames(data))) - stop("Either provide `data` with column names or provide `features` as column indices") - features <- match(features, colnames(data)) - } - if (n_col > length(features)) n_col <- length(features) - - if (is.list(shap_contrib)) { # multiclass: either choose a class or merge - shap_contrib <- if (!is.null(target_class)) shap_contrib[[target_class + 1]] - else Reduce("+", lapply(shap_contrib, abs)) - } - - shap_contrib <- shap_contrib[, features, drop = FALSE] - data <- data[, features, drop = FALSE] - cols <- colnames(data) - if (is.null(cols)) cols <- colnames(shap_contrib) - if (is.null(cols)) cols <- paste0('X', seq_len(ncol(data))) - colnames(data) <- cols - colnames(shap_contrib) <- cols - if (plot && which == "1d") { op <- par(mfrow = c(ceiling(length(features) / n_col), n_col), oma = c(0, 0, 0, 0) + 0.2, mar = c(3.5, 3.5, 0, 0) + 0.1, mgp = c(1.7, 0.6, 0)) - for (f in cols) { + for (f in features) { ord <- order(data[, f]) x <- data[, f][ord] y <- shap_contrib[, f][ord] @@ -216,3 +182,105 @@ xgb.plot.shap <- function(data, shap_contrib = NULL, features = NULL, top_n = 1, } invisible(list(data = data, shap_contrib = shap_contrib)) } + +#' SHAP contribution dependency summary plot +#' +#' Compare SHAP contributions of different features. +#' +#' A point plot (each point representing one sample from \code{data}) is +#' produced for each feature, with the points plotted on the SHAP value axis. +#' Each point (observation) is coloured based on its feature value. The plot +#' hence allows us to see which features have a negative / positive contribution +#' on the model prediction, and whether the contribution is different for larger +#' or smaller values of the feature. We effectively try to replicate the +#' \code{summary_plot} function from https://github.com/slundberg/shap. +#' +#' @inheritParams xgb.plot.shap +#' +#' @return A \code{ggplot2} object. +#' @export +#' +#' @examples See \code{\link{xgb.plot.shap}}. +#' @seealso \code{\link{xgb.plot.shap}}, \code{\link{xgb.ggplot.shap.summary}}, +#' \code{\url{https://github.com/slundberg/shap}} +xgb.plot.shap.summary <- function(data, shap_contrib = NULL, features = NULL, top_n = 10, model = NULL, + trees = NULL, target_class = NULL, approxcontrib = FALSE, subsample = NULL) { + # Only ggplot implementation is available. + xgb.ggplot.shap.summary(data, shap_contrib, features, top_n, model, trees, target_class, approxcontrib, subsample) +} + +#' Prepare data for SHAP plots. To be used in xgb.plot.shap, xgb.plot.shap.summary, etc. +#' Internal utility function. +#' +#' @return A list containing: 'data', a matrix containing sample observations +#' and their feature values; 'shap_contrib', a matrix containing the SHAP contribution +#' values for these observations. +xgb.shap.data <- function(data, shap_contrib = NULL, features = NULL, top_n = 1, model = NULL, + trees = NULL, target_class = NULL, approxcontrib = FALSE, + subsample = NULL, max_observations = 100000) { + if (!is.matrix(data) && !inherits(data, "dgCMatrix")) + stop("data: must be either matrix or dgCMatrix") + + if (is.null(shap_contrib) && (is.null(model) || !inherits(model, "xgb.Booster"))) + stop("when shap_contrib is not provided, one must provide an xgb.Booster model") + + if (is.null(features) && (is.null(model) || !inherits(model, "xgb.Booster"))) + stop("when features are not provided, one must provide an xgb.Booster model to rank the features") + + if (!is.null(shap_contrib) && + (!is.matrix(shap_contrib) || nrow(shap_contrib) != nrow(data) || ncol(shap_contrib) != ncol(data) + 1)) + stop("shap_contrib is not compatible with the provided data") + + if (is.character(features) && is.null(colnames(data))) + stop("either provide `data` with column names or provide `features` as column indices") + + if (is.null(model$feature_names) && model$nfeatures != ncol(data)) + stop("if model has no feature_names, columns in `data` must match features in model") + + if (!is.null(subsample)) { + idx <- sample(x = seq_len(nrow(data)), size = as.integer(subsample * nrow(data)), replace = FALSE) + } else { + idx <- seq_len(min(nrow(data), max_observations)) + } + data <- data[idx, ] + if (is.null(colnames(data))) { + colnames(data) <- paste0("X", seq_len(ncol(data))) + } + + if (!is.null(shap_contrib)) { + if (is.list(shap_contrib)) { # multiclass: either choose a class or merge + shap_contrib <- if (!is.null(target_class)) shap_contrib[[target_class + 1]] else Reduce("+", lapply(shap_contrib, abs)) + } + shap_contrib <- shap_contrib[idx, ] + if (is.null(colnames(shap_contrib))) { + colnames(shap_contrib) <- paste0("X", seq_len(ncol(data))) + } + } else { + shap_contrib <- predict(model, newdata = data, predcontrib = TRUE, approxcontrib = approxcontrib) + if (is.list(shap_contrib)) { # multiclass: either choose a class or merge + shap_contrib <- if (!is.null(target_class)) shap_contrib[[target_class + 1]] else Reduce("+", lapply(shap_contrib, abs)) + } + } + + if (is.null(features)) { + if (!is.null(model$feature_names)) { + imp <- xgb.importance(model = model, trees = trees) + } else { + imp <- xgb.importance(model = model, trees = trees, feature_names = colnames(data)) + } + top_n <- top_n[1] + if (top_n < 1 | top_n > 100) stop("top_n: must be an integer within [1, 100]") + features <- imp$Feature[1:min(top_n, NROW(imp))] + } + if (is.character(features)) { + features <- match(features, colnames(data)) + } + + shap_contrib <- shap_contrib[, features, drop = FALSE] + data <- data[, features, drop = FALSE] + + list( + data = data, + shap_contrib = shap_contrib + ) +} diff --git a/R-package/tests/testthat/test_helpers.R b/R-package/tests/testthat/test_helpers.R index 2ee1acf56..86c0efd02 100644 --- a/R-package/tests/testthat/test_helpers.R +++ b/R-package/tests/testthat/test_helpers.R @@ -351,11 +351,47 @@ test_that("xgb.plot.deepness works", { xgb.ggplot.deepness(model = bst.Tree) }) +test_that("xgb.shap.data works when top_n is provided", { + data_list <- xgb.shap.data(data = sparse_matrix, model = bst.Tree, top_n = 2) + expect_equal(names(data_list), c("data", "shap_contrib")) + expect_equal(NCOL(data_list$data), 2) + expect_equal(NCOL(data_list$shap_contrib), 2) + expect_equal(NROW(data_list$data), NROW(data_list$shap_contrib)) + expect_gt(length(colnames(data_list$data)), 0) + expect_gt(length(colnames(data_list$shap_contrib)), 0) + + # for multiclass without target class provided + data_list <- xgb.shap.data(data = as.matrix(iris[, -5]), model = mbst.Tree, top_n = 2) + expect_equal(dim(data_list$shap_contrib), c(nrow(iris), 2)) + # for multiclass with target class provided + data_list <- xgb.shap.data(data = as.matrix(iris[, -5]), model = mbst.Tree, top_n = 2, target_class = 0) + expect_equal(dim(data_list$shap_contrib), c(nrow(iris), 2)) +}) + +test_that("xgb.shap.data works with subsampling", { + data_list <- xgb.shap.data(data = sparse_matrix, model = bst.Tree, top_n = 2, subsample = 0.8) + expect_equal(NROW(data_list$data), as.integer(0.8 * nrow(sparse_matrix))) + expect_equal(NROW(data_list$data), NROW(data_list$shap_contrib)) +}) + +test_that("prepare.ggplot.shap.data works", { + data_list <- xgb.shap.data(data = sparse_matrix, model = bst.Tree, top_n = 2) + plot_data <- prepare.ggplot.shap.data(data_list, normalize = TRUE) + expect_s3_class(plot_data, "data.frame") + expect_equal(names(plot_data), c("id", "feature", "feature_value", "shap_value")) + expect_s3_class(plot_data$feature, "factor") + # Each observation should have 1 row for each feature + expect_equal(nrow(plot_data), nrow(sparse_matrix) * 2) +}) + test_that("xgb.plot.shap works", { sh <- xgb.plot.shap(data = sparse_matrix, model = bst.Tree, top_n = 2, col = 4) expect_equal(names(sh), c("data", "shap_contrib")) - expect_equal(NCOL(sh$data), 2) - expect_equal(NCOL(sh$shap_contrib), 2) +}) + +test_that("xgb.plot.shap.summary works", { + xgb.plot.shap.summary(data = sparse_matrix, model = bst.Tree, top_n = 2) + xgb.ggplot.shap.summary(data = sparse_matrix, model = bst.Tree, top_n = 2) }) test_that("check.deprecation works", {