[R] Fix single sample prediction. (#7524)
This commit is contained in:
@@ -157,3 +157,28 @@ test_that("multiclass feature interactions work", {
|
||||
# sums WRT columns must be close to feature contributions
|
||||
expect_lt(max(abs(apply(intr, c(1, 2, 3), sum) - aperm(cont, c(3, 1, 2)))), 0.00001)
|
||||
})
|
||||
|
||||
|
||||
test_that("SHAP single sample works", {
|
||||
train <- agaricus.train
|
||||
test <- agaricus.test
|
||||
booster <- xgboost(
|
||||
data = train$data,
|
||||
label = train$label,
|
||||
max_depth = 2,
|
||||
nrounds = 4,
|
||||
objective = "binary:logistic",
|
||||
)
|
||||
|
||||
predt <- predict(
|
||||
booster,
|
||||
newdata = train$data[1, , drop = FALSE], predcontrib = TRUE
|
||||
)
|
||||
expect_equal(dim(predt), c(1, dim(train$data)[2] + 1))
|
||||
|
||||
predt <- predict(
|
||||
booster,
|
||||
newdata = train$data[1, , drop = FALSE], predinteraction = TRUE
|
||||
)
|
||||
expect_equal(dim(predt), c(1, dim(train$data)[2] + 1, dim(train$data)[2] + 1))
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user