[R] R-interface for SHAP interactions (#3636)
* add R-interface for SHAP interactions * update docs for new roxygen version
This commit is contained in:
committed by
GitHub
parent
10c31ab2cb
commit
5b662cbe1c
@@ -129,11 +129,13 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) {
|
||||
#' logistic regression would result in predictions for log-odds instead of probabilities.
|
||||
#' @param ntreelimit limit the number of model's trees or boosting iterations used in prediction (see Details).
|
||||
#' It will use all the trees by default (\code{NULL} value).
|
||||
#' @param predleaf whether predict leaf index instead.
|
||||
#' @param predcontrib whether to return feature contributions to individual predictions instead (see Details).
|
||||
#' @param predleaf whether to predict leaf index.
|
||||
#' @param predcontrib whether to return feature contributions to individual predictions (see Details).
|
||||
#' @param approxcontrib whether to use a fast approximation for feature contributions (see Details).
|
||||
#' @param predinteraction whether to return contributions of feature interactions to individual predictions (see Details).
|
||||
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
|
||||
#' prediction outputs per case. This option has no effect when \code{predleaf = TRUE}.
|
||||
#' prediction outputs per case. This option has no effect when any of the predleaf, predcontrib,
|
||||
#' or predinteraction flags is TRUE.
|
||||
#' @param ... Parameters passed to \code{predict.xgb.Booster}
|
||||
#'
|
||||
#' @details
|
||||
@@ -158,6 +160,11 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) {
|
||||
#' Setting \code{approxcontrib = TRUE} approximates these values following the idea explained
|
||||
#' in \url{http://blog.datadive.net/interpreting-random-forests/}.
|
||||
#'
|
||||
#' With \code{predinteraction = TRUE}, SHAP values of contributions of interaction of each pair of features
|
||||
#' are computed. Note that this operation might be rather expensive in terms of compute and memory.
|
||||
#' Since it quadratically depends on the number of features, it is recommended to perform selection
|
||||
#' of the most important features first. See below about the format of the returned results.
|
||||
#'
|
||||
#' @return
|
||||
#' For regression or binary classification, it returns a vector of length \code{nrow(newdata)}.
|
||||
#' For multiclass classification, either a \code{num_class * nrow(newdata)} vector or
|
||||
@@ -173,6 +180,14 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) {
|
||||
#' such a matrix. The contribution values are on the scale of untransformed margin
|
||||
#' (e.g., for binary classification would mean that the contributions are log-odds deviations from bias).
|
||||
#'
|
||||
#' When \code{predinteraction = TRUE} and it is not a multiclass setting, the output is a 3d array with
|
||||
#' dimensions \code{c(nrow, num_features + 1, num_features + 1)}. The off-diagonal (in the last two dimensions)
|
||||
#' elements represent the interaction contributions of different feature pairs. The array is symmetric with respect to the last
|
||||
#' two dimensions. The "+ 1" columns correspond to bias. Summing this array along the last dimension should
|
||||
#' produce practically the same result as predict with \code{predcontrib = TRUE}.
|
||||
#' For a multiclass case, a list of \code{num_class} elements is returned, where each element is
|
||||
#' such an array.
|
||||
#'
|
||||
#' @seealso
|
||||
#' \code{\link{xgb.train}}.
|
||||
#'
|
||||
@@ -269,7 +284,8 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) {
|
||||
#' @rdname predict.xgb.Booster
|
||||
#' @export
|
||||
predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FALSE, ntreelimit = NULL,
|
||||
predleaf = FALSE, predcontrib = FALSE, approxcontrib = FALSE, reshape = FALSE, ...) {
|
||||
predleaf = FALSE, predcontrib = FALSE, approxcontrib = FALSE, predinteraction = FALSE,
|
||||
reshape = FALSE, ...) {
|
||||
|
||||
object <- xgb.Booster.complete(object, saveraw = FALSE)
|
||||
if (!inherits(newdata, "xgb.DMatrix"))
|
||||
@@ -285,7 +301,8 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
|
||||
if (ntreelimit < 0)
|
||||
stop("ntreelimit cannot be negative")
|
||||
|
||||
option <- 0L + 1L * as.logical(outputmargin) + 2L * as.logical(predleaf) + 4L * as.logical(predcontrib) + 8L * as.logical(approxcontrib)
|
||||
option <- 0L + 1L * as.logical(outputmargin) + 2L * as.logical(predleaf) + 4L * as.logical(predcontrib) +
|
||||
8L * as.logical(approxcontrib) + 16L * as.logical(predinteraction)
|
||||
|
||||
ret <- .Call(XGBoosterPredict_R, object$handle, newdata, option[1], as.integer(ntreelimit))
|
||||
|
||||
@@ -305,17 +322,28 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
|
||||
} else if (predcontrib) {
|
||||
n_col1 <- ncol(newdata) + 1
|
||||
n_group <- npred_per_case / n_col1
|
||||
dnames <- if (!is.null(colnames(newdata))) list(NULL, c(colnames(newdata), "BIAS")) else NULL
|
||||
cnames <- if (!is.null(colnames(newdata))) c(colnames(newdata), "BIAS") else NULL
|
||||
ret <- if (n_ret == n_row) {
|
||||
matrix(ret, ncol = 1, dimnames = dnames)
|
||||
matrix(ret, ncol = 1, dimnames = list(NULL, cnames))
|
||||
} else if (n_group == 1) {
|
||||
matrix(ret, nrow = n_row, byrow = TRUE, dimnames = dnames)
|
||||
matrix(ret, nrow = n_row, byrow = TRUE, dimnames = list(NULL, cnames))
|
||||
} else {
|
||||
grp_mask <- rep(seq_len(n_col1), n_row) +
|
||||
rep((seq_len(n_row) - 1) * n_col1 * n_group, each = n_col1)
|
||||
lapply(seq_len(n_group), function(g) {
|
||||
matrix(ret[grp_mask + n_col1 * (g - 1)], nrow = n_row, byrow = TRUE, dimnames = dnames)
|
||||
})
|
||||
arr <- array(ret, c(n_col1, n_group, n_row),
|
||||
dimnames = list(cnames, NULL, NULL)) %>% aperm(c(2,3,1)) # [group, row, col]
|
||||
lapply(seq_len(n_group), function(g) arr[g,,])
|
||||
}
|
||||
} else if (predinteraction) {
|
||||
n_col1 <- ncol(newdata) + 1
|
||||
n_group <- npred_per_case / n_col1^2
|
||||
cnames <- if (!is.null(colnames(newdata))) c(colnames(newdata), "BIAS") else NULL
|
||||
ret <- if (n_ret == n_row) {
|
||||
matrix(ret, ncol = 1, dimnames = list(NULL, cnames))
|
||||
} else if (n_group == 1) {
|
||||
array(ret, c(n_col1, n_col1, n_row), dimnames = list(cnames, cnames, NULL)) %>% aperm(c(3,1,2))
|
||||
} else {
|
||||
arr <- array(ret, c(n_col1, n_col1, n_group, n_row),
|
||||
dimnames = list(cnames, cnames, NULL, NULL)) %>% aperm(c(3,4,1,2)) # [group, row, col1, col2]
|
||||
lapply(seq_len(n_group), function(g) arr[g,,,])
|
||||
}
|
||||
} else if (reshape && npred_per_case > 1) {
|
||||
ret <- matrix(ret, nrow = n_row, byrow = TRUE)
|
||||
|
||||
Reference in New Issue
Block a user