[R] Allow passing data.frame to SHAP (#10744)
This commit is contained in:
@@ -102,6 +102,27 @@ xgb.ggplot.deepness <- function(model = NULL, which = c("2x1", "max.depth", "med
|
||||
#' @export
|
||||
xgb.ggplot.shap.summary <- function(data, shap_contrib = NULL, features = NULL, top_n = 10, model = NULL,
|
||||
trees = NULL, target_class = NULL, approxcontrib = FALSE, subsample = NULL) {
|
||||
if (inherits(data, "xgb.DMatrix")) {
|
||||
stop(
|
||||
"'xgb.ggplot.shap.summary' is not compatible with 'xgb.DMatrix' objects. Try passing a matrix or data.frame."
|
||||
)
|
||||
}
|
||||
cols_categ <- NULL
|
||||
if (!is.null(model)) {
|
||||
ftypes <- getinfo(model, "feature_type")
|
||||
if (NROW(ftypes)) {
|
||||
if (length(ftypes) != ncol(data)) {
|
||||
stop(sprintf("'data' has incorrect number of columns (expected: %d, got: %d).", length(ftypes), ncol(data)))
|
||||
}
|
||||
cols_categ <- colnames(data)[ftypes == "c"]
|
||||
}
|
||||
} else if (inherits(data, "data.frame")) {
|
||||
cols_categ <- names(data)[sapply(data, function(x) is.factor(x) || is.character(x))]
|
||||
}
|
||||
if (NROW(cols_categ)) {
|
||||
warning("Categorical features are ignored in 'xgb.ggplot.shap.summary'.")
|
||||
}
|
||||
|
||||
data_list <- xgb.shap.data(
|
||||
data = data,
|
||||
shap_contrib = shap_contrib,
|
||||
@@ -114,6 +135,10 @@ xgb.ggplot.shap.summary <- function(data, shap_contrib = NULL, features = NULL,
|
||||
subsample = subsample,
|
||||
max_observations = 10000 # 10,000 samples per feature.
|
||||
)
|
||||
if (NROW(cols_categ)) {
|
||||
data_list <- lapply(data_list, function(x) x[, !(colnames(x) %in% cols_categ), drop = FALSE])
|
||||
}
|
||||
|
||||
p_data <- prepare.ggplot.shap.data(data_list, normalize = TRUE)
|
||||
# Reverse factor levels so that the first level is at the top of the plot
|
||||
p_data[, "feature" := factor(feature, rev(levels(feature)))]
|
||||
@@ -134,7 +159,8 @@ xgb.ggplot.shap.summary <- function(data, shap_contrib = NULL, features = NULL,
|
||||
#' @param data_list The result of `xgb.shap.data()`.
|
||||
#' @param normalize Whether to standardize feature values to mean 0 and
|
||||
#' standard deviation 1. This is useful for comparing multiple features on the same
|
||||
#' plot. Default is `FALSE`.
|
||||
#' plot. Default is `FALSE`. Note that it cannot be used when the data contains
|
||||
#' categorical features.
|
||||
#' @return A `data.table` containing the observation ID, the feature name, the
|
||||
#' feature value (normalized if specified), and the SHAP contribution value.
|
||||
#' @noRd
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
#'
|
||||
#' Visualizes SHAP values against feature values to gain an impression of feature effects.
|
||||
#'
|
||||
#' @param data The data to explain as a `matrix` or `dgCMatrix`.
|
||||
#' @param data The data to explain as a `matrix`, `dgCMatrix`, or `data.frame`.
|
||||
#' @param shap_contrib Matrix of SHAP contributions of `data`.
|
||||
#' The default (`NULL`) computes it from `model` and `data`.
|
||||
#' @param features Vector of column indices or feature names to plot. When `NULL`
|
||||
@@ -285,8 +285,11 @@ xgb.plot.shap.summary <- function(data, shap_contrib = NULL, features = NULL, to
|
||||
xgb.shap.data <- function(data, shap_contrib = NULL, features = NULL, top_n = 1, model = NULL,
|
||||
trees = NULL, target_class = NULL, approxcontrib = FALSE,
|
||||
subsample = NULL, max_observations = 100000) {
|
||||
if (!is.matrix(data) && !inherits(data, "dgCMatrix"))
|
||||
stop("data: must be either matrix or dgCMatrix")
|
||||
if (!inherits(data, c("matrix", "dsparseMatrix", "data.frame")))
|
||||
stop("data: must be matrix, sparse matrix, or data.frame.")
|
||||
if (inherits(data, "data.frame") && length(class(data)) > 1L) {
|
||||
data <- as.data.frame(data)
|
||||
}
|
||||
|
||||
if (is.null(shap_contrib) && (is.null(model) || !inherits(model, "xgb.Booster")))
|
||||
stop("when shap_contrib is not provided, one must provide an xgb.Booster model")
|
||||
@@ -311,7 +314,14 @@ xgb.shap.data <- function(data, shap_contrib = NULL, features = NULL, top_n = 1,
|
||||
stop("if model has no feature_names, columns in `data` must match features in model")
|
||||
|
||||
if (!is.null(subsample)) {
|
||||
idx <- sample(x = seq_len(nrow(data)), size = as.integer(subsample * nrow(data)), replace = FALSE)
|
||||
if (subsample <= 0 || subsample >= 1) {
|
||||
stop("'subsample' must be a number between zero and one (non-inclusive).")
|
||||
}
|
||||
sample_size <- as.integer(subsample * nrow(data))
|
||||
if (sample_size < 2) {
|
||||
stop("Sampling fraction involves less than 2 rows.")
|
||||
}
|
||||
idx <- sample(x = seq_len(nrow(data)), size = sample_size, replace = FALSE)
|
||||
} else {
|
||||
idx <- seq_len(min(nrow(data), max_observations))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user