[R] Allow passing data.frame to SHAP (#10744)
This commit is contained in:
parent
ec8cfb3267
commit
f52f11e1d7
@ -102,6 +102,27 @@ xgb.ggplot.deepness <- function(model = NULL, which = c("2x1", "max.depth", "med
|
|||||||
#' @export
|
#' @export
|
||||||
xgb.ggplot.shap.summary <- function(data, shap_contrib = NULL, features = NULL, top_n = 10, model = NULL,
|
xgb.ggplot.shap.summary <- function(data, shap_contrib = NULL, features = NULL, top_n = 10, model = NULL,
|
||||||
trees = NULL, target_class = NULL, approxcontrib = FALSE, subsample = NULL) {
|
trees = NULL, target_class = NULL, approxcontrib = FALSE, subsample = NULL) {
|
||||||
|
if (inherits(data, "xgb.DMatrix")) {
|
||||||
|
stop(
|
||||||
|
"'xgb.ggplot.shap.summary' is not compatible with 'xgb.DMatrix' objects. Try passing a matrix or data.frame."
|
||||||
|
)
|
||||||
|
}
|
||||||
|
cols_categ <- NULL
|
||||||
|
if (!is.null(model)) {
|
||||||
|
ftypes <- getinfo(model, "feature_type")
|
||||||
|
if (NROW(ftypes)) {
|
||||||
|
if (length(ftypes) != ncol(data)) {
|
||||||
|
stop(sprintf("'data' has incorrect number of columns (expected: %d, got: %d).", length(ftypes), ncol(data)))
|
||||||
|
}
|
||||||
|
cols_categ <- colnames(data)[ftypes == "c"]
|
||||||
|
}
|
||||||
|
} else if (inherits(data, "data.frame")) {
|
||||||
|
cols_categ <- names(data)[sapply(data, function(x) is.factor(x) || is.character(x))]
|
||||||
|
}
|
||||||
|
if (NROW(cols_categ)) {
|
||||||
|
warning("Categorical features are ignored in 'xgb.ggplot.shap.summary'.")
|
||||||
|
}
|
||||||
|
|
||||||
data_list <- xgb.shap.data(
|
data_list <- xgb.shap.data(
|
||||||
data = data,
|
data = data,
|
||||||
shap_contrib = shap_contrib,
|
shap_contrib = shap_contrib,
|
||||||
@ -114,6 +135,10 @@ xgb.ggplot.shap.summary <- function(data, shap_contrib = NULL, features = NULL,
|
|||||||
subsample = subsample,
|
subsample = subsample,
|
||||||
max_observations = 10000 # 10,000 samples per feature.
|
max_observations = 10000 # 10,000 samples per feature.
|
||||||
)
|
)
|
||||||
|
if (NROW(cols_categ)) {
|
||||||
|
data_list <- lapply(data_list, function(x) x[, !(colnames(x) %in% cols_categ), drop = FALSE])
|
||||||
|
}
|
||||||
|
|
||||||
p_data <- prepare.ggplot.shap.data(data_list, normalize = TRUE)
|
p_data <- prepare.ggplot.shap.data(data_list, normalize = TRUE)
|
||||||
# Reverse factor levels so that the first level is at the top of the plot
|
# Reverse factor levels so that the first level is at the top of the plot
|
||||||
p_data[, "feature" := factor(feature, rev(levels(feature)))]
|
p_data[, "feature" := factor(feature, rev(levels(feature)))]
|
||||||
@ -134,7 +159,8 @@ xgb.ggplot.shap.summary <- function(data, shap_contrib = NULL, features = NULL,
|
|||||||
#' @param data_list The result of `xgb.shap.data()`.
|
#' @param data_list The result of `xgb.shap.data()`.
|
||||||
#' @param normalize Whether to standardize feature values to mean 0 and
|
#' @param normalize Whether to standardize feature values to mean 0 and
|
||||||
#' standard deviation 1. This is useful for comparing multiple features on the same
|
#' standard deviation 1. This is useful for comparing multiple features on the same
|
||||||
#' plot. Default is `FALSE`.
|
#' plot. Default is `FALSE`. Note that it cannot be used when the data contains
|
||||||
|
#' categorical features.
|
||||||
#' @return A `data.table` containing the observation ID, the feature name, the
|
#' @return A `data.table` containing the observation ID, the feature name, the
|
||||||
#' feature value (normalized if specified), and the SHAP contribution value.
|
#' feature value (normalized if specified), and the SHAP contribution value.
|
||||||
#' @noRd
|
#' @noRd
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
#'
|
#'
|
||||||
#' Visualizes SHAP values against feature values to gain an impression of feature effects.
|
#' Visualizes SHAP values against feature values to gain an impression of feature effects.
|
||||||
#'
|
#'
|
||||||
#' @param data The data to explain as a `matrix` or `dgCMatrix`.
|
#' @param data The data to explain as a `matrix`, `dgCMatrix`, or `data.frame`.
|
||||||
#' @param shap_contrib Matrix of SHAP contributions of `data`.
|
#' @param shap_contrib Matrix of SHAP contributions of `data`.
|
||||||
#' The default (`NULL`) computes it from `model` and `data`.
|
#' The default (`NULL`) computes it from `model` and `data`.
|
||||||
#' @param features Vector of column indices or feature names to plot. When `NULL`
|
#' @param features Vector of column indices or feature names to plot. When `NULL`
|
||||||
@ -285,8 +285,11 @@ xgb.plot.shap.summary <- function(data, shap_contrib = NULL, features = NULL, to
|
|||||||
xgb.shap.data <- function(data, shap_contrib = NULL, features = NULL, top_n = 1, model = NULL,
|
xgb.shap.data <- function(data, shap_contrib = NULL, features = NULL, top_n = 1, model = NULL,
|
||||||
trees = NULL, target_class = NULL, approxcontrib = FALSE,
|
trees = NULL, target_class = NULL, approxcontrib = FALSE,
|
||||||
subsample = NULL, max_observations = 100000) {
|
subsample = NULL, max_observations = 100000) {
|
||||||
if (!is.matrix(data) && !inherits(data, "dgCMatrix"))
|
if (!inherits(data, c("matrix", "dsparseMatrix", "data.frame")))
|
||||||
stop("data: must be either matrix or dgCMatrix")
|
stop("data: must be matrix, sparse matrix, or data.frame.")
|
||||||
|
if (inherits(data, "data.frame") && length(class(data)) > 1L) {
|
||||||
|
data <- as.data.frame(data)
|
||||||
|
}
|
||||||
|
|
||||||
if (is.null(shap_contrib) && (is.null(model) || !inherits(model, "xgb.Booster")))
|
if (is.null(shap_contrib) && (is.null(model) || !inherits(model, "xgb.Booster")))
|
||||||
stop("when shap_contrib is not provided, one must provide an xgb.Booster model")
|
stop("when shap_contrib is not provided, one must provide an xgb.Booster model")
|
||||||
@ -311,7 +314,14 @@ xgb.shap.data <- function(data, shap_contrib = NULL, features = NULL, top_n = 1,
|
|||||||
stop("if model has no feature_names, columns in `data` must match features in model")
|
stop("if model has no feature_names, columns in `data` must match features in model")
|
||||||
|
|
||||||
if (!is.null(subsample)) {
|
if (!is.null(subsample)) {
|
||||||
idx <- sample(x = seq_len(nrow(data)), size = as.integer(subsample * nrow(data)), replace = FALSE)
|
if (subsample <= 0 || subsample >= 1) {
|
||||||
|
stop("'subsample' must be a number between zero and one (non-inclusive).")
|
||||||
|
}
|
||||||
|
sample_size <- as.integer(subsample * nrow(data))
|
||||||
|
if (sample_size < 2) {
|
||||||
|
stop("Sampling fraction involves less than 2 rows.")
|
||||||
|
}
|
||||||
|
idx <- sample(x = seq_len(nrow(data)), size = sample_size, replace = FALSE)
|
||||||
} else {
|
} else {
|
||||||
idx <- seq_len(min(nrow(data), max_observations))
|
idx <- seq_len(min(nrow(data), max_observations))
|
||||||
}
|
}
|
||||||
|
|||||||
@ -33,7 +33,7 @@ xgb.plot.shap(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
\arguments{
|
\arguments{
|
||||||
\item{data}{The data to explain as a \code{matrix} or \code{dgCMatrix}.}
|
\item{data}{The data to explain as a \code{matrix}, \code{dgCMatrix}, or \code{data.frame}.}
|
||||||
|
|
||||||
\item{shap_contrib}{Matrix of SHAP contributions of \code{data}.
|
\item{shap_contrib}{Matrix of SHAP contributions of \code{data}.
|
||||||
The default (\code{NULL}) computes it from \code{model} and \code{data}.}
|
The default (\code{NULL}) computes it from \code{model} and \code{data}.}
|
||||||
|
|||||||
@ -30,7 +30,7 @@ xgb.plot.shap.summary(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
\arguments{
|
\arguments{
|
||||||
\item{data}{The data to explain as a \code{matrix} or \code{dgCMatrix}.}
|
\item{data}{The data to explain as a \code{matrix}, \code{dgCMatrix}, or \code{data.frame}.}
|
||||||
|
|
||||||
\item{shap_contrib}{Matrix of SHAP contributions of \code{data}.
|
\item{shap_contrib}{Matrix of SHAP contributions of \code{data}.
|
||||||
The default (\code{NULL}) computes it from \code{model} and \code{data}.}
|
The default (\code{NULL}) computes it from \code{model} and \code{data}.}
|
||||||
|
|||||||
@ -449,6 +449,26 @@ test_that("xgb.shap.data works with subsampling", {
|
|||||||
expect_equal(NROW(data_list$data), NROW(data_list$shap_contrib))
|
expect_equal(NROW(data_list$data), NROW(data_list$shap_contrib))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test_that("xgb.shap.data works with data frames", {
|
||||||
|
data(mtcars)
|
||||||
|
df <- mtcars
|
||||||
|
df$cyl <- factor(df$cyl)
|
||||||
|
x <- df[, -1]
|
||||||
|
y <- df$mpg
|
||||||
|
dm <- xgb.DMatrix(x, label = y, nthread = 1L)
|
||||||
|
model <- xgb.train(
|
||||||
|
data = dm,
|
||||||
|
params = list(
|
||||||
|
max_depth = 2,
|
||||||
|
nthread = 1
|
||||||
|
),
|
||||||
|
nrounds = 2
|
||||||
|
)
|
||||||
|
data_list <- xgb.shap.data(data = df[, -1], model = model, top_n = 2, subsample = 0.8)
|
||||||
|
expect_equal(NROW(data_list$data), as.integer(0.8 * nrow(df)))
|
||||||
|
expect_equal(NROW(data_list$data), NROW(data_list$shap_contrib))
|
||||||
|
})
|
||||||
|
|
||||||
test_that("prepare.ggplot.shap.data works", {
|
test_that("prepare.ggplot.shap.data works", {
|
||||||
.skip_if_vcd_not_available()
|
.skip_if_vcd_not_available()
|
||||||
data_list <- xgb.shap.data(data = sparse_matrix, model = bst.Tree, top_n = 2)
|
data_list <- xgb.shap.data(data = sparse_matrix, model = bst.Tree, top_n = 2)
|
||||||
@ -472,6 +492,44 @@ test_that("xgb.plot.shap.summary works", {
|
|||||||
expect_silent(xgb.ggplot.shap.summary(data = sparse_matrix, model = bst.Tree, top_n = 2))
|
expect_silent(xgb.ggplot.shap.summary(data = sparse_matrix, model = bst.Tree, top_n = 2))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test_that("xgb.plot.shap.summary ignores categorical features", {
|
||||||
|
.skip_if_vcd_not_available()
|
||||||
|
data(mtcars)
|
||||||
|
df <- mtcars
|
||||||
|
df$cyl <- factor(df$cyl)
|
||||||
|
levels(df$cyl) <- c("a", "b", "c")
|
||||||
|
x <- df[, -1]
|
||||||
|
y <- df$mpg
|
||||||
|
dm <- xgb.DMatrix(x, label = y, nthread = 1L)
|
||||||
|
model <- xgb.train(
|
||||||
|
data = dm,
|
||||||
|
params = list(
|
||||||
|
max_depth = 2,
|
||||||
|
nthread = 1
|
||||||
|
),
|
||||||
|
nrounds = 2
|
||||||
|
)
|
||||||
|
expect_warning({
|
||||||
|
xgb.ggplot.shap.summary(data = x, model = model, top_n = 2)
|
||||||
|
})
|
||||||
|
|
||||||
|
x_num <- mtcars[, -1]
|
||||||
|
x_num$gear <- as.numeric(x_num$gear) - 1
|
||||||
|
x_num <- as.matrix(x_num)
|
||||||
|
dm <- xgb.DMatrix(x_num, label = y, feature_types = c(rep("q", 8), "c", "q"), nthread = 1L)
|
||||||
|
model <- xgb.train(
|
||||||
|
data = dm,
|
||||||
|
params = list(
|
||||||
|
max_depth = 2,
|
||||||
|
nthread = 1
|
||||||
|
),
|
||||||
|
nrounds = 2
|
||||||
|
)
|
||||||
|
expect_warning({
|
||||||
|
xgb.ggplot.shap.summary(data = x_num, model = model, top_n = 2)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
test_that("check.deprecation works", {
|
test_that("check.deprecation works", {
|
||||||
ttt <- function(a = NNULL, DUMMY = NULL, ...) {
|
ttt <- function(a = NNULL, DUMMY = NULL, ...) {
|
||||||
check.deprecation(...)
|
check.deprecation(...)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user