parent
3e2d7519a6
commit
328d1e18db
@ -435,7 +435,8 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
|
||||
lapply(seq_len(n_groups), function(g) arr[g, , ])
|
||||
} else {
|
||||
## remove the first axis (group)
|
||||
as.matrix(arr[1, , ])
|
||||
dn <- dimnames(arr)
|
||||
matrix(arr[1, , ], nrow = dim(arr)[2], ncol = dim(arr)[3], dimnames = c(dn[2], dn[3]))
|
||||
}
|
||||
} else if (predinteraction) {
|
||||
## Predict interaction
|
||||
@ -447,7 +448,8 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
|
||||
lapply(seq_len(n_groups), function(g) arr[g, , , ])
|
||||
} else {
|
||||
## remove the first axis (group)
|
||||
arr[1, , , ]
|
||||
arr <- arr[1, , , , drop = FALSE]
|
||||
array(arr, dim = dim(arr)[2:4], dimnames(arr)[2:4])
|
||||
}
|
||||
} else {
|
||||
## Normal prediction
|
||||
|
||||
@ -157,3 +157,28 @@ test_that("multiclass feature interactions work", {
|
||||
# sums WRT columns must be close to feature contributions
|
||||
expect_lt(max(abs(apply(intr, c(1, 2, 3), sum) - aperm(cont, c(3, 1, 2)))), 0.00001)
|
||||
})
|
||||
|
||||
|
||||
test_that("SHAP single sample works", {
|
||||
train <- agaricus.train
|
||||
test <- agaricus.test
|
||||
booster <- xgboost(
|
||||
data = train$data,
|
||||
label = train$label,
|
||||
max_depth = 2,
|
||||
nrounds = 4,
|
||||
objective = "binary:logistic",
|
||||
)
|
||||
|
||||
predt <- predict(
|
||||
booster,
|
||||
newdata = train$data[1, , drop = FALSE], predcontrib = TRUE
|
||||
)
|
||||
expect_equal(dim(predt), c(1, dim(train$data)[2] + 1))
|
||||
|
||||
predt <- predict(
|
||||
booster,
|
||||
newdata = train$data[1, , drop = FALSE], predinteraction = TRUE
|
||||
)
|
||||
expect_equal(dim(predt), c(1, dim(train$data)[2] + 1, dim(train$data)[2] + 1))
|
||||
})
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user